Skip to content

Commit

Permalink
Update gunpowder_trainer.py to swap to(device) and float()
Browse files Browse the repository at this point in the history
  • Loading branch information
psobolewskiPhD authored Oct 24, 2024
1 parent 4152f7c commit ee7c762
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion dacapo/experiments/trainers/gunpowder_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def iterate(self, num_iterations, model, optimizer, device):
param.grad = None

t_start_prediction = time.time()
predicted = model.forward(torch.as_tensor(raw[raw.roi]).to(device).float())
predicted = model.forward(torch.as_tensor(raw[raw.roi]).float().to(device))
predicted.retain_grad()
loss = self._loss.compute(
predicted,
Expand Down

0 comments on commit ee7c762

Please sign in to comment.