diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index d42a2e9ac..e747a6283 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -1002,6 +1002,7 @@ def start_embedding_lookup(
                             context,
                             source_stream=self._data_dist_stream,
                             target_stream=stream,
+                            stream_context=self._stream_context,
                         )
                         event = torch.get_device_module(self._device).Event()
                         event.record()
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index efc772a90..d653c4e03 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -11,6 +11,7 @@
 import itertools
 import logging
 from collections import defaultdict, OrderedDict
+from contextlib import AbstractContextManager
 from dataclasses import dataclass, field
 from itertools import chain
 
@@ -297,6 +298,14 @@ def __init__(
                 f"Preproc module {fqn} has no dist stream. This may cause race conditions and NaNs during training!"
             )
 
+        device: torch.device = cast(torch.Stream, self._dist_stream).device
+        # pyre-ignore
+        self._stream_context = (
+            torch.get_device_module(device).stream
+            if device.type in ["cuda", "mtia"]
+            else torch.cuda.stream
+        )
+
     @property
     def preproc_module(self) -> torch.nn.Module:
         return self._preproc_module
@@ -341,8 +350,7 @@ def forward(self, *input, **kwargs) -> Any:
 
         with record_function(f"## sdd_input_preproc {self._context.index} ##"):
             # should be no-op as we call this in dist stream
-            # pyre-ignore[6]: torch.cuda.Stream is a wrapper around torch.Stream
-            with torch.cuda.stream(self._dist_stream):
+            with self._stream_context(self._dist_stream):
                 res = self._preproc_module(*args, **kwargs)
 
             # Ensure preproc modules output is safe to use from default stream later
@@ -364,8 +372,7 @@ def forward(self, *input, **kwargs) -> Any:
                     f"Result of preproc module {self._fqn} is of type {type(res)}. We currently expect it to be a Tensor, Pipelineable, Iterable, or Dict to handle memory safety. If your output is not of this type, please add support for it above. Otherwise you might run into NaNs or CUDA Illegal Memory issues during training!"
                 )
 
-            # pyre-ignore[6]: torch.cuda.Stream is a wrapper around torch.Stream
-            with torch.cuda.stream(self._default_stream):
+            with self._stream_context(self._default_stream):
                 # Cache results, only during _start_data_dist
                 self._context.preproc_fwd_results[self._fqn] = res
 
@@ -760,10 +767,11 @@ def _start_embedding_lookup(
     context: EmbeddingTrainPipelineContext,
     source_stream: Optional[torch.Stream],
     target_stream: Optional[torch.Stream],
+    # pyre-ignore[2]
+    stream_context: Callable[..., AbstractContextManager[Any, Any]],
 ) -> None:
     module_context = context.module_contexts[module.forward.name]
-    # pyre-ignore[6]: torch.cuda.Stream is a wrapper around torch.Stream
-    with torch.cuda.stream(source_stream):
+    with stream_context(source_stream):
         kjt = context.input_dist_tensors_requests[module.forward.name].wait()
 
     if target_stream is not None:
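
Note (not part of the patch): below is a minimal, self-contained sketch of the device-aware stream-context selection this diff introduces, assuming a recent PyTorch with torch.get_device_module. The helper name make_stream_context is illustrative only and does not exist in torchrec.

# Illustrative sketch only -- mirrors the patch's selection logic, not torchrec code.
# Assumes torch.get_device_module is available (PyTorch >= 2.3).
from contextlib import AbstractContextManager
from typing import Any, Callable

import torch


def make_stream_context(
    device: torch.device,
) -> Callable[..., AbstractContextManager[Any]]:
    # CUDA and MTIA expose a `stream` context manager on their device module;
    # anything else falls back to torch.cuda.stream, as in the patch.
    if device.type in ["cuda", "mtia"]:
        return torch.get_device_module(device).stream
    return torch.cuda.stream


if torch.cuda.is_available():
    device = torch.device("cuda")
    stream_context = make_stream_context(device)
    side_stream = torch.cuda.Stream()
    with stream_context(side_stream):
        # Work enqueued here runs on `side_stream` rather than the default stream.
        out = torch.ones(4, device=device) * 2
    # Make the default stream wait before consuming `out`.
    torch.cuda.current_stream().wait_stream(side_stream)

Resolving the context manager once (as the patch does in __init__ and then threads through _start_embedding_lookup) avoids repeating the per-device dispatch at every stream switch.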