From 90c7548370d75196b9cbe851cb8345abcf8825bc Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Tue, 28 Nov 2023 16:27:10 +0100
Subject: [PATCH] PT distributed training, better DDP wrapping

Fix #1451
---
 returnn/torch/distributed.py | 67 ------------------------------------
 returnn/torch/engine.py      | 51 ++++++++++++++++++++++-----
 2 files changed, 42 insertions(+), 76 deletions(-)

diff --git a/returnn/torch/distributed.py b/returnn/torch/distributed.py
index d08f7ee2d1..aa92ec53e4 100644
--- a/returnn/torch/distributed.py
+++ b/returnn/torch/distributed.py
@@ -122,70 +122,3 @@ def _find_tensors(obj):
     if isinstance(obj, dict):
         return itertools.chain(*map(_find_tensors, obj.values()))
     return []
-
-
-@contextmanager
-def ddp_train_forward_ctx(pt_model: DistributedDataParallel):
-    """
-    the original (unwrapped) module is passed to the train step, therefore here we set up the right context
-    as what DistributedDataParallel.forward does internally
-    """
-    if torch.is_grad_enabled() and pt_model.require_backward_grad_sync:
-        assert pt_model.logger is not None
-        pt_model.logger.set_runtime_stats_and_log()
-        pt_model.num_iterations += 1
-        pt_model.reducer.prepare_for_forward()
-
-    with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
-        if torch.is_grad_enabled() and pt_model.require_backward_grad_sync:
-            assert pt_model.logger is not None
-            pt_model.logger.set_runtime_stats_and_log()
-            pt_model.num_iterations += 1
-            pt_model.reducer.prepare_for_forward()
-
-        work = Join.notify_join_context(pt_model)
-        if work:
-            # noinspection PyProtectedMember
-            pt_model.reducer._set_forward_pass_work_handle(work, pt_model._divide_by_initial_world_size)
-
-        # noinspection PyProtectedMember
-        if torch.is_grad_enabled() and pt_model.reducer._rebuild_buckets():
-            pt_model._has_rebuilt_buckets = True
-
-        # noinspection PyProtectedMember
-        if pt_model._check_sync_bufs_pre_fwd():
-            # noinspection PyProtectedMember
-            pt_model._sync_buffers()
-
-        # noinspection PyProtectedMember
-        if pt_model._join_config.enable:
-            # Notify joined ranks whether they should sync in backwards pass or not.
-            # noinspection PyProtectedMember
-            pt_model._check_global_requires_backward_grad_sync(is_joined_rank=False)
-
-        # noinspection PyProtectedMember
-        with pt_model._inside_ddp_forward():
-            yield
-
-        # noinspection PyProtectedMember
-        if pt_model._check_sync_bufs_post_fwd():
-            # noinspection PyProtectedMember
-            pt_model._sync_buffers()
-
-        if torch.is_grad_enabled() and pt_model.require_backward_grad_sync:
-            pt_model.require_forward_param_sync = True
-            # We'll return the output object verbatim since it is a freeform
-            # object. We need to find any tensors in this object, though,
-            # because we need to figure out which parameters were used during
-            # this forward pass, to ensure we short circuit reduction for any
-            # unused parameters. Only if `find_unused_parameters` is set.
-            if pt_model.find_unused_parameters and not pt_model.static_graph:
-                # Do not need to populate this for static graph.
-                train_ctx = rf.get_run_ctx()
-                loss = list(train_ctx.losses.values())[0].loss.raw_tensor
-                # noinspection PyProtectedMember
-                pt_model.reducer.prepare_for_backward(list(_find_tensors(loss)))
-            else:
-                pt_model.reducer.prepare_for_backward([])
-        else:
-            pt_model.require_forward_param_sync = False
diff --git a/returnn/torch/engine.py b/returnn/torch/engine.py
index 9133d92a6e..0889cfb36a 100644
--- a/returnn/torch/engine.py
+++ b/returnn/torch/engine.py
@@ -196,7 +196,9 @@ def init_train_from_config(
             # wrap the model use torch distributed class
             self._ddp_pt_model = self._torch_distributed_class(
-                self._pt_model, device_ids=get_device_ids(), **self._torch_distributed_options
+                module=_WrappedModuleRunStep(module=self._pt_model, engine=self),
+                device_ids=get_device_ids(),
+                **self._torch_distributed_options,
             )
         self._updater = Updater(
             config=self.config, network=self._pt_model, device=self._device, initial_learning_rate=self.learning_rate
@@ -542,10 +544,17 @@ def _create_data_loader(self, dataset: Dataset) -> DataLoader:
             **loader_opts,
         )
 
-    def _run_step(self, extern_data: TensorDict, *, train_flag: bool = False, train_func: bool):
+    def _run_step(
+        self, extern_data: TensorDict, *, train_flag: bool = False, train_func: bool, _inside_wrapped: bool = False
+    ):
         """
         :param extern_data: model inputs for the step
+        :return: Nothing, all outputs are written to the run context (:func:`rf.get_run_ctx`).
         """
+        if self._ddp_pt_model is not None and not _inside_wrapped:
+            self._ddp_pt_model(extern_data=extern_data, train_flag=train_flag, train_func=train_func)
+            return
+
         if train_func:
             assert self._train_step_func is not None
             rf.init_train_step_run_ctx(train_flag=train_flag, step=self.global_train_step)
@@ -555,15 +564,9 @@ def _run_step(self, extern_data: TensorDict, *, train_flag: bool = False, train_
                 expected_outputs=self._forward_step_expected_outputs, step=self.global_train_step
             )
 
-        from returnn.torch.distributed import ddp_train_forward_ctx
-
         with autocast(
             device_type=self._device.split(":")[0], dtype=self._autocast_dtype
-        ) if self._use_autocast else nullcontext(), ddp_train_forward_ctx(pt_model=self._ddp_pt_model) if isinstance(
-            self._ddp_pt_model, DistributedDataParallel
-        ) else nullcontext(), rf.set_default_device_ctx(
-            self._device
-        ):
+        ) if self._use_autocast else nullcontext(), rf.set_default_device_ctx(self._device):
             sentinel_kw = {"__fwd_compatible_random_arg_%i" % int(random() * 100): None}
             if train_func:
                 self._train_step_func(model=self._orig_model, extern_data=extern_data, **sentinel_kw)
@@ -1038,3 +1041,33 @@ def _data_loader_worker_init_func(worker_id: int):
     if sys.platform == "linux":
         with open("/proc/self/comm", "w") as f:
             f.write(f"TDL worker {worker_id}")
+
+
+class _WrappedModuleRunStep(torch.nn.Module):
+    """
+    Wraps any Torch module (pure or RF),
+    and the `forward` function calls the run step function (train_step or forward_step)
+    and returns all produced raw tensors via the run context (losses or outputs) (:func:`rf.get_run_ctx`).
+    This is useful to use the API of DistributedDataParallel and potentially other PyTorch modules.
+    """
+
+    def __init__(self, *, module: torch.nn.Module, engine: Engine):
+        super().__init__()
+        self.module = module
+        self.engine = engine
+
+    def forward(self, *args, **kwargs):
+        # noinspection PyProtectedMember
+        self.engine._run_step(*args, **kwargs, _inside_wrapped=True)
+
+        # Note that we don't directly use the returned raw values here,
+        # but the PyTorch API might,
+        # e.g. DistributedDataParallel checks it to collect all gradients.
+        # We will use rf.get_run_ctx() later again in the engine to access these values.
+        res = {}
+        ctx = rf.get_run_ctx()
+        for name, out in ctx.outputs.data.items():
+            res["output/" + name] = out.raw_tensor
+        for name, loss in ctx.losses.items():
+            res["loss/" + name] = loss.loss.raw_tensor
+        return res
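Note (not part of the patch): the core idea behind _WrappedModuleRunStep is that
DistributedDataParallel only inspects the tensors returned by forward() to figure out
which parameters were used in the step, so the wrapper runs the step function and
returns the raw loss/output tensors it produced. Below is a minimal standalone sketch
of that pattern; the names WrappedRunStep, step_func and _toy_step are hypothetical
placeholders for illustration, not RETURNN API.

import torch


class WrappedRunStep(torch.nn.Module):
    """Runs a user step function in forward() and returns its raw loss tensors."""

    def __init__(self, *, module: torch.nn.Module, step_func):
        super().__init__()
        # Registering the wrapped module here makes its parameters visible to DDP.
        self.module = module
        self.step_func = step_func

    def forward(self, *args, **kwargs):
        # The step function computes losses as a dict of name -> raw tensor.
        losses = self.step_func(self.module, *args, **kwargs)
        # Return the raw tensors so DDP's reducer can trace which parameters were used.
        return {"loss/" + name: t for name, t in losses.items()}


def _toy_step(model, x, targets):
    # Hypothetical step function, for illustration only.
    return {"mse": torch.nn.functional.mse_loss(model(x), targets)}


if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)
    wrapped = WrappedRunStep(module=model, step_func=_toy_step)
    # In the real distributed setup, `wrapped` would be passed to
    # torch.nn.parallel.DistributedDataParallel(wrapped, ...) after init_process_group().
    losses = wrapped(torch.randn(8, 4), torch.randn(8, 2))
    losses["loss/mse"].backward()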