From 890efa228feed2b41ed3949a4e2bca723f7a8ddf Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Mon, 28 Oct 2024 15:45:11 +0100 Subject: [PATCH 1/2] Corrected ".to" statement for DDP case --- mala/network/runner.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mala/network/runner.py b/mala/network/runner.py index a67a79eb0..a8910c2ad 100644 --- a/mala/network/runner.py +++ b/mala/network/runner.py @@ -850,7 +850,12 @@ def _forward_entire_snapshot( # Ensure the Network is on the correct device. # This line is necessary because GPU acceleration may have been # activated AFTER loading a model. - self.network.to(self.network.params._configuration["device"]) + if self.parameters_full.use_ddp: + self.network.module.to( + self.network.params._configuration["device"] + ) + else: + self.network.to(self.network.params._configuration["device"]) # Determine where the snapshot begins and ends. from_index = 0 From 1881d106fb39538e5477f851b13fea1694c67b2e Mon Sep 17 00:00:00 2001 From: Lenz Fiedler Date: Tue, 29 Oct 2024 14:57:48 +0100 Subject: [PATCH 2/2] Corrected one line only to have the second one be incorrect too --- mala/network/runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mala/network/runner.py b/mala/network/runner.py index a8910c2ad..ff111b10b 100644 --- a/mala/network/runner.py +++ b/mala/network/runner.py @@ -852,7 +852,7 @@ def _forward_entire_snapshot( # activated AFTER loading a model. if self.parameters_full.use_ddp: self.network.module.to( - self.network.params._configuration["device"] + self.network.module.params._configuration["device"] ) else: self.network.to(self.network.params._configuration["device"])