diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py index 094b4a236..3cd0f990b 100644 --- a/torchrec/distributed/train_pipeline/train_pipelines.py +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -249,6 +249,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out: True ) torch._dynamo.config.skip_torchrec = False + torch._dynamo.config.optimize_ddp = False # Importing only before compilation to not slow-done train_pipelines import torch.ops.import_module("fbgemm_gpu.sparse_ops")