diff --git a/torchrec/distributed/comm_ops.py b/torchrec/distributed/comm_ops.py index 031804937..79ba859ca 100644 --- a/torchrec/distributed/comm_ops.py +++ b/torchrec/distributed/comm_ops.py @@ -119,6 +119,8 @@ def _wait_impl(self) -> W: """ ret = self.wait_function.apply(self.pg, self, self.dummy_tensor) + if isinstance(ret, torch.Tensor): + ret.record_stream(torch.get_device_module(ret.device).current_stream()) self.req = None self.tensor = None return ret diff --git a/torchrec/distributed/tests/test_sharding_plan.py b/torchrec/distributed/tests/test_sharding_plan.py index d5ba9e774..c49c6a05b 100644 --- a/torchrec/distributed/tests/test_sharding_plan.py +++ b/torchrec/distributed/tests/test_sharding_plan.py @@ -119,6 +119,10 @@ def _test_sharding( @skip_if_asan_class class ConstructParameterShardingAndShardTest(MultiProcessTestBase): + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) # pyre-fixme[56] @given( per_param_sharding=st.sampled_from( diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py index 05f6b6235..9c4e30326 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py @@ -152,6 +152,8 @@ def forward(self, x): fqn="test_module", args=[], context=TrainPipelineContext(), + default_stream=MagicMock(), + dist_stream=MagicMock(), ) # self-check - we want the state dict be the same between vanilla model and "rewritten model" self.assertDictEqual(model.state_dict(), rewritten_model.state_dict()) diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py index 5f63fbc74..e9189bd3f 100644 --- a/torchrec/distributed/train_pipeline/train_pipelines.py +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -536,6 +536,7 @@ def _pipeline_model( model=self._model, context=context, dist_stream=self._data_dist_stream, + default_stream=torch.get_device_module(self._device).current_stream(), batch=batch, apply_jit=self._apply_jit, pipelined_forward=pipelined_forward, @@ -576,7 +577,7 @@ def copy_batch_to_gpu( StopIteration: if the dataloader iterator is exhausted; unless `self._execute_all_batches=True`, then returns None. """ - context = None + context = self._create_context() with record_function(f"## copy_batch_to_gpu {self._next_index} ##"): with self._stream_context(self._memcpy_stream): batch = self._next_batch(dataloader_iter) @@ -584,7 +585,6 @@ def copy_batch_to_gpu( batch = _to_device(batch, self._device, non_blocking=True) elif not self._execute_all_batches: raise StopIteration - context = self._create_context() return batch, context def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]: @@ -747,25 +747,9 @@ def __init__( ) self._start_batch = start_batch self._stash_gradients = stash_gradients + logger.debug(f"Starting semi-sync run at batch: {self._start_batch}") - # use two data streams to support two concurrent batches - self._embedding_odd_stream: Optional[torch.Stream] = ( - (torch.get_device_module(self._device).Stream(priority=0)) - if device.type in ["cuda", "mtia"] - else None - ) - self._embedding_even_stream: Optional[torch.Stream] = ( - (torch.get_device_module(self._device).Stream(priority=0)) - if device.type in ["cuda", "mtia"] - else None - ) - self._overarch_stream: Optional[torch.Stream] = ( - (torch.get_device_module(self._device).Stream(priority=-1)) - if device.type in ["cuda", "mtia"] - else None - ) - self._embedding_odd_streams: List[Optional[torch.Stream]] = [] - self._embedding_even_streams: List[Optional[torch.Stream]] = [] + self._embedding_streams: List[Optional[torch.Stream]] = [] self._gradients: Dict[str, torch.Tensor] = {} def _grad_swap(self) -> None: @@ -778,12 +762,7 @@ def _grad_swap(self) -> None: def _init_embedding_streams(self) -> None: for _ in self._pipelined_modules: - self._embedding_odd_streams.append( - (torch.get_device_module(self._device).Stream(priority=0)) - if self._device.type in ["cuda", "mtia"] - else None - ) - self._embedding_even_streams.append( + self._embedding_streams.append( (torch.get_device_module(self._device).Stream(priority=0)) if self._device.type in ["cuda", "mtia"] else None @@ -839,13 +818,9 @@ def is_semi_sync(self) -> bool: return self.contexts[0].index >= self._start_batch return False - def _mlp_optimizer_step(self) -> None: + def _mlp_optimizer_step(self, current_batch: int) -> None: # special case: not all optimizers support optim.step() on null gradidents - if ( - len(self.batches) >= 1 - and self.contexts[0].index == self._start_batch - and self._stash_gradients - ): + if current_batch == self._start_batch and self._stash_gradients: return self._optimizer.step() @@ -860,42 +835,56 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out: self.contexts[2], ) - losses, output = self._mlp_forward(cast(In, self.batches[0]), self.contexts[0]) + batch, context = self.batches[0], self.contexts[0] + is_semi_sync = context.index is not None and context.index >= self._start_batch + iteration: int = context.index or 0 + losses, output = self._mlp_forward(cast(In, batch), context) + + # After this point, pipelined preproc/module forward won't be called + # so we can advance their contexts to the context of the next batch already + # and also pop batch and context from self.batches and self.contexts + self.dequeue_batch() + + # batch no longer needed - delete to free up memory + del batch + + # cached preproc fwd results no longer needed - delete to free up memory + del context.preproc_fwd_results # batch i+3 self.enqueue_batch(dataloader_iter) - if len(self.batches) >= 2 and self.is_semi_sync(): + if len(self.batches) >= 1 and is_semi_sync: # pyre-ignore [6] - self.start_embedding_lookup(self.batches[1], self.contexts[1]) + self.start_embedding_lookup(self.batches[0], self.contexts[0]) - if len(self.batches) >= 3: - self.wait_sparse_data_dist(self.contexts[2]) + if len(self.batches) >= 2: + self.wait_sparse_data_dist(self.contexts[1]) if self._model.training: - with record_function(f"## backward {self.contexts[0].index} ##"): + with record_function(f"## backward {iteration} ##"): torch.sum(losses, dim=0).backward() - # pyre-ignore [6] - self.embedding_backward(self.contexts[0]) + with record_function(f"## emb_backward {iteration} ##"): + # pyre-ignore [6] + self.embedding_backward(context) - with record_function( - f"## optimizer {cast(int, self.contexts[0].index) - 1} ##" - ): - if self.is_semi_sync() and self._stash_gradients: + del context # context is no longer needed, deleting to free up memory + + with record_function(f"## optimizer {iteration - 1} ##"): + if is_semi_sync and self._stash_gradients: self._grad_swap() - self._mlp_optimizer_step() + self._mlp_optimizer_step(iteration) - with record_function( - f"## zero_grad {cast(int, self.contexts[0].index) - 1} ##" - ): + with record_function(f"## zero_grad {iteration - 1} ##"): self._optimizer.zero_grad() + else: + del context - if len(self.batches) >= 2 and not self.is_semi_sync(): + if len(self.batches) >= 1 and not is_semi_sync: torch.cuda.synchronize() # needed to avoid race condition # pyre-ignore [6] - self.start_embedding_lookup(self.batches[1], self.contexts[1]) + self.start_embedding_lookup(self.batches[0], self.contexts[0]) - self.dequeue_batch() return output def _mlp_forward( @@ -909,14 +898,9 @@ def _mlp_forward( def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None: default_stream = torch.get_device_module(self._device).current_stream() - streams = ( - self._embedding_even_streams - if cast(int, context.index) % 2 == 0 - else self._embedding_odd_streams - ) assert len(context.embedding_features) == len(context.embedding_tensors) for stream, emb_tensors, embedding_features, detached_emb_tensors in zip( - streams, + self._embedding_streams, context.embedding_tensors, context.embedding_features, context.detached_embedding_tensors, @@ -939,7 +923,9 @@ def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None: embs_to_backprop.append(tensor) grads_to_use.append(grad) else: - if isinstance(features, Iterable): + if isinstance(features, str): + invalid_features.append(features) + elif isinstance(features, Iterable): invalid_features.extend(features) else: invalid_features.append(features) @@ -1012,13 +998,14 @@ def start_embedding_lookup( batch, context, torch.get_device_module(self._device).current_stream() ) for i, module in enumerate(self._pipelined_modules): - stream = ( - self._embedding_even_streams[i] - if cast(int, context.index) % 2 == 0 - else self._embedding_odd_streams[i] - ) + stream = self._embedding_streams[i] with self._stream_context(stream): - _start_embedding_lookup(module, context, stream) + _start_embedding_lookup( + module, + context, + source_stream=self._data_dist_stream, + target_stream=stream, + ) event = torch.get_device_module(self._device).Event() event.record() context.events.append(event) diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index c69fcabea..efc772a90 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -21,6 +21,7 @@ cast, Dict, Generic, + Iterable, Iterator, List, Optional, @@ -33,6 +34,7 @@ import torch from torch import distributed as dist +from torchrec.distributed.types import LazyAwaitable if not torch._running_with_deploy(): from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2 @@ -135,7 +137,9 @@ class PrefetchTrainPipelineContext(TrainPipelineContext): @dataclass class EmbeddingTrainPipelineContext(TrainPipelineContext): - embedding_a2a_requests: Dict[str, Multistreamable] = field(default_factory=dict) + embedding_a2a_requests: Dict[str, LazyAwaitable[Multistreamable]] = field( + default_factory=dict + ) embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list) embedding_features: List[List[Union[str, List[str]]]] = field(default_factory=list) detached_embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list) @@ -229,6 +233,21 @@ def _build_args_kwargs( return args, kwargs +def recursive_record_stream( + # pyre-fixme[2]: Parameter `re` must have a type that does not contain `Any` + res: Union[torch.Tensor, Pipelineable, Iterable[Any], Dict[Any, Any]], + stream: torch.Stream, +) -> None: + if isinstance(res, (torch.Tensor, Pipelineable)): + res.record_stream(stream) + elif isinstance(res, (list, tuple)): + for v in res: + recursive_record_stream(v, stream) + elif isinstance(res, dict): + for v in res.values(): + recursive_record_stream(v, stream) + + class PipelinedPreproc(torch.nn.Module): """ Wrapper around preproc module found during model graph traversal for sparse data dist @@ -258,12 +277,25 @@ def __init__( fqn: str, args: List[ArgInfo], context: TrainPipelineContext, + # TODO: make streams non-optional - skipping now to avoid ripple effect + default_stream: Optional[torch.Stream], + dist_stream: Optional[torch.Stream], ) -> None: super().__init__() self._preproc_module = preproc_module self._fqn = fqn self._args = args self._context = context + self._default_stream = default_stream + self._dist_stream = dist_stream + if not default_stream: + logger.warning( + f"Preproc module {fqn} has no default stream. This may cause race conditions and NaNs during training!" + ) + if not dist_stream: + logger.warning( + f"Preproc module {fqn} has no dist stream. This may cause race conditions and NaNs during training!" + ) @property def preproc_module(self) -> torch.nn.Module: @@ -308,9 +340,35 @@ def forward(self, *input, **kwargs) -> Any: args, kwargs = _build_args_kwargs(input[0], self._args) with record_function(f"## sdd_input_preproc {self._context.index} ##"): - res = self._preproc_module(*args, **kwargs) - # Cache results, only during _start_data_dist - self._context.preproc_fwd_results[self._fqn] = res + # should be no-op as we call this in dist stream + # pyre-ignore[6]: torch.cuda.Stream is a wrapper around torch.Stream + with torch.cuda.stream(self._dist_stream): + res = self._preproc_module(*args, **kwargs) + + # Ensure preproc modules output is safe to use from default stream later + if self._default_stream and self._dist_stream: + self._default_stream.wait_stream(self._dist_stream) + + if isinstance(res, (torch.Tensor, Pipelineable, Iterable, Dict)): + # Result from module forward might be a complex type such as + # Tuple[KeyedJaggedTensor, Dict[str, torch.Tensor]] + # In this case, we need to first iterate over each element of tuple + # and call record_stream on first item as KJT is Pipelineable + # for the second item (Dict), we iterate over the values and call + # record_stream accordingly. + + # pyre-ignore[6] + recursive_record_stream(res, self._default_stream) + elif self._context.index == 0: + logger.warning( + f"Result of preproc module {self._fqn} is of type {type(res)}. We currently expect it to be a Tensor, Pipelineable, Iterable, or Dict to handle memory safety. If your output is not of this type, please add support for it above. Otherwise you might run into NaNs or CUDA Illegal Memory issues during training!" + ) + + # pyre-ignore[6]: torch.cuda.Stream is a wrapper around torch.Stream + with torch.cuda.stream(self._default_stream): + # Cache results, only during _start_data_dist + self._context.preproc_fwd_results[self._fqn] = res + return res @property @@ -382,13 +440,16 @@ def load_state_dict( return self._preproc_module.load_state_dict(state_dict, strict=strict) -class BaseForward: +TForwardContext = TypeVar("TForwardContext", bound=TrainPipelineContext) + + +class BaseForward(Generic[TForwardContext]): def __init__( self, name: str, args: List[ArgInfo], module: ShardedModule, - context: TrainPipelineContext, + context: TForwardContext, stream: Optional[torch.Stream] = None, ) -> None: self._name = name @@ -406,14 +467,14 @@ def name(self) -> str: def args(self) -> List[ArgInfo]: return self._args - def set_context(self, context: TrainPipelineContext) -> None: + def set_context(self, context: TForwardContext) -> None: self._context = context - def get_context(self) -> TrainPipelineContext: + def get_context(self) -> TForwardContext: return self._context -class PipelinedForward(BaseForward): +class PipelinedForward(BaseForward[TrainPipelineContext]): # pyre-ignore [2, 24] def __call__(self, *input, **kwargs) -> Awaitable: assert ( @@ -446,21 +507,20 @@ def __call__(self, *input, **kwargs) -> Awaitable: return self._module.compute_and_output_dist(ctx, data) -class EmbeddingPipelinedForward(BaseForward): +class EmbeddingPipelinedForward(BaseForward[EmbeddingTrainPipelineContext]): # pyre-ignore [2, 24] def __call__(self, *input, **kwargs) -> Awaitable: assert ( - self._name - # pyre-ignore [16] - in self._context.embedding_a2a_requests + self._name in self._context.embedding_a2a_requests ), "Invalid EmbeddingPipelinedForward usage, please do not directly call model.forward()" ctx = self._context.module_contexts.pop(self._name) + cur_stream = torch.get_device_module(self._device).current_stream() + if self._stream is not None: torch.get_device_module(self._device).current_stream().wait_stream( self._stream ) - cur_stream = torch.get_device_module(self._device).current_stream() ctx.record_stream(cur_stream) awaitable = self._context.embedding_a2a_requests.pop(self._name) embeddings = awaitable.wait() # trigger awaitable manually for type checking @@ -475,14 +535,12 @@ def __call__(self, *input, **kwargs) -> Awaitable: jt._values = detached_tensor tensors.append(tensor) detached_tensors.append(detached_tensor) - # pyre-ignore [16] self._context.embedding_tensors.append(tensors) - # pyre-ignore [16] self._context.embedding_features.append(list(embeddings.keys())) - # pyre-ignore [16] self._context.detached_embedding_tensors.append(detached_tensors) else: assert isinstance(embeddings, KeyedTensor) + embeddings.record_stream(cur_stream) tensor = embeddings.values() detached_tensor = tensor.detach().requires_grad_() detached_tensor.retain_grad() @@ -502,7 +560,7 @@ def __call__(self, *input, **kwargs) -> Awaitable: return LazyNoWait(embeddings) -class PrefetchPipelinedForward(BaseForward): +class PrefetchPipelinedForward(BaseForward[PrefetchTrainPipelineContext]): def __init__( self, name: str, @@ -522,12 +580,9 @@ def __init__( # pyre-ignore [2, 24] def __call__(self, *input, **kwargs) -> Awaitable: assert ( - self._name - # pyre-ignore [16] - in self._context.module_input_post_prefetch + self._name in self._context.module_input_post_prefetch ), "Invalid PrefetchPipelinedForward usage, please do not directly call model.forward()" data = self._context.module_input_post_prefetch.pop(self._name) - # pyre-ignore [16] ctx = self._context.module_contexts_post_prefetch.pop(self._name) # Make sure that both result of input_dist and context @@ -703,15 +758,18 @@ def _start_data_dist( def _start_embedding_lookup( module: ShardedModule, context: EmbeddingTrainPipelineContext, - stream: Optional[torch.Stream], + source_stream: Optional[torch.Stream], + target_stream: Optional[torch.Stream], ) -> None: - kjt = context.input_dist_tensors_requests[module.forward.name].wait() module_context = context.module_contexts[module.forward.name] - if stream: - kjt.record_stream(stream) - module_context.record_stream(stream) + # pyre-ignore[6]: torch.cuda.Stream is a wrapper around torch.Stream + with torch.cuda.stream(source_stream): + kjt = context.input_dist_tensors_requests[module.forward.name].wait() + + if target_stream is not None: + kjt.record_stream(target_stream) + module_context.record_stream(target_stream) a2a_awaitable = module.compute_and_output_dist(module_context, kjt) - # pyre-ignore[6] context.embedding_a2a_requests[module.forward.name] = a2a_awaitable @@ -828,6 +886,8 @@ def _get_node_args_helper( # Add `None` constants to arg info only for preproc modules # Defaults to False for backward compatibility for_preproc_module: bool = False, + default_stream: Optional[torch.Stream] = None, + dist_stream: Optional[torch.Stream] = None, ) -> Tuple[List[ArgInfo], int]: """ Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s. @@ -1002,6 +1062,8 @@ def _get_node_args_helper( context, pipeline_preproc, True, + default_stream=default_stream, + dist_stream=dist_stream, ) if num_found_safe_preproc_args == total_num_args: logger.info( @@ -1017,6 +1079,8 @@ def _get_node_args_helper( preproc_module_fqn, preproc_args, context, + default_stream=default_stream, + dist_stream=dist_stream, ) # module swap @@ -1047,6 +1111,8 @@ def _get_node_args( context: TrainPipelineContext, pipeline_preproc: bool, for_preproc_module: bool = False, + default_stream: Optional[torch.Stream] = None, + dist_stream: Optional[torch.Stream] = None, ) -> Tuple[List[ArgInfo], int]: num_found = 0 @@ -1058,6 +1124,8 @@ def _get_node_args( context, pipeline_preproc, for_preproc_module, + default_stream=default_stream, + dist_stream=dist_stream, ) kwargs_arg_info_list, num_found = _get_node_args_helper( model, @@ -1067,6 +1135,8 @@ def _get_node_args( context, pipeline_preproc, for_preproc_module, + default_stream=default_stream, + dist_stream=dist_stream, ) # Replace with proper names for kwargs @@ -1185,12 +1255,13 @@ def _pipeline_detach_model( # pyre-ignore[3] def _rewrite_model( # noqa C901 model: torch.nn.Module, - context: TrainPipelineContext, + context: TForwardContext, dist_stream: Optional[torch.Stream], batch: Optional[In] = None, apply_jit: bool = False, - pipelined_forward: Type[BaseForward] = PipelinedForward, + pipelined_forward: Type[BaseForward[TrainPipelineContext]] = PipelinedForward, pipeline_preproc: bool = False, + default_stream: Optional[torch.Stream] = None, ) -> Tuple[ List[ShardedModule], torch.nn.Module, @@ -1249,6 +1320,8 @@ def _rewrite_model( # noqa C901 pipelined_preprocs, context, pipeline_preproc, + default_stream=default_stream, + dist_stream=dist_stream, ) if num_found == total_num_args: @@ -1535,8 +1608,9 @@ def __init__( Callable[[KeyedJaggedTensor], Awaitable[KJTAllToAllTensorsAwaitable]] ] = [] - self._pipelined_forward = ( - PrefetchPipelinedForward if prefetch_stream else PipelinedForward + self._pipelined_forward: Type[BaseForward[TrainPipelineContext]] = cast( + Type[BaseForward[TrainPipelineContext]], + (PrefetchPipelinedForward if prefetch_stream else PipelinedForward), ) self._default_stream: Optional[torch.Stream] = ( @@ -1588,6 +1662,7 @@ def start_sparse_data_dist(self, batch: In) -> In: batch=batch, apply_jit=self.apply_jit, pipelined_forward=self._pipelined_forward, + default_stream=self._default_stream, ) # initializes input dist, so we can override input dist forwards _start_data_dist(self._pipelined_modules, batch, self.context)