Skip to content

Commit

Permalink
Update gunpowder_trainer.py to swap to(device) and float() (#312)
Browse files Browse the repository at this point in the history
One more place where you have to(device) and then float().
numpy default float is 64, so the to(device) will fail if the device is
MPS that doesn't support float64.
  • Loading branch information
rhoadesScholar authored Oct 24, 2024
2 parents 4152f7c + ee7c762 commit e58a01d
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion dacapo/experiments/trainers/gunpowder_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def iterate(self, num_iterations, model, optimizer, device):
param.grad = None

t_start_prediction = time.time()
predicted = model.forward(torch.as_tensor(raw[raw.roi]).to(device).float())
predicted = model.forward(torch.as_tensor(raw[raw.roi]).float().to(device))
predicted.retain_grad()
loss = self._loss.compute(
predicted,
Expand Down

0 comments on commit e58a01d

Please sign in to comment.