diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index 1f54b75fd..f3af1edc0 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -167,12 +167,11 @@ def init_model_fn( size_based_auto_wrap_policy, min_num_params=2 ** 10 ) model = FSDP( - self._model, + model, use_orig_params=True, auto_wrap_policy=auto_wrap_policy, device_id=RANK ) - else: model = torch.nn.DataParallel(model) return model, None