Skip to content

Commit

Permalink
Merge pull request #602 from RandomDefaultUser/fix_ddp_training
Browse files Browse the repository at this point in the history
Fixing DDP training
  • Loading branch information
RandomDefaultUser authored Oct 30, 2024
2 parents 27dec15 + 1881d10 commit b3d117e
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion mala/network/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,12 @@ def _forward_entire_snapshot(
# Ensure the Network is on the correct device.
# This line is necessary because GPU acceleration may have been
# activated AFTER loading a model.
self.network.to(self.network.params._configuration["device"])
if self.parameters_full.use_ddp:
self.network.module.to(
self.network.module.params._configuration["device"]
)
else:
self.network.to(self.network.params._configuration["device"])

# Determine where the snapshot begins and ends.
from_index = 0
Expand Down

0 comments on commit b3d117e

Please sign in to comment.