From bb9268b10d0640e16c87fce10a44746f3274d1ca Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Thu, 24 Oct 2024 10:58:18 +0200 Subject: [PATCH 1/3] =?UTF-8?q?feat:=20=E2=9C=A8=20Add=20support=20for=20t?= =?UTF-8?q?raining=20on=20Apple=20M1/M2/M3=20(mps)=20devices.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Also add log printing of device used for training. --- dacapo/compute_context/local_torch.py | 2 ++ dacapo/train.py | 1 + 2 files changed, 3 insertions(+) diff --git a/dacapo/compute_context/local_torch.py b/dacapo/compute_context/local_torch.py index 5a0371a43..045300790 100644 --- a/dacapo/compute_context/local_torch.py +++ b/dacapo/compute_context/local_torch.py @@ -60,6 +60,8 @@ def device(self): if free < self.oom_limit: # less than 1 GB free, decrease chance of OOM return torch.device("cpu") return torch.device("cuda") + elif torch.backends.mps.is_available(): + return torch.device("mps") else: return torch.device("cpu") return torch.device(self._device) diff --git a/dacapo/train.py b/dacapo/train.py index 70b845db2..4e4101f8d 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -135,6 +135,7 @@ def train_run(run: Run, do_validate=True): compute_context = create_compute_context() run.model = run.model.to(compute_context.device) run.move_optimizer(compute_context.device) + logger.info(f"Training on {compute_context.device}") array_store = create_array_store() run.trainer.iteration = trained_until From 4152f7c4baa9107558b3702a56c4fd9544299e60 Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Thu, 24 Oct 2024 11:18:11 +0200 Subject: [PATCH 2/3] =?UTF-8?q?fix:=20=F0=9F=90=9B=20Fix=20float64=20for?= =?UTF-8?q?=20mps=20device.=20(convert=20to=20float32=20first)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dacapo/experiments/trainers/gunpowder_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index dffd28e17..4b5649e77 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -313,8 +313,8 @@ def iterate(self, num_iterations, model, optimizer, device): predicted.retain_grad() loss = self._loss.compute( predicted, - torch.as_tensor(target[target.roi]).to(device).float(), - torch.as_tensor(weight[weight.roi]).to(device).float(), + torch.as_tensor(target[target.roi]).float().to(device), + torch.as_tensor(weight[weight.roi]).float().to(device), ) loss.backward() optimizer.step() From ee7c762725ea6d86663b30e612bf285621522d32 Mon Sep 17 00:00:00 2001 From: Peter Sobolewski <76622105+psobolewskiPhD@users.noreply.github.com> Date: Thu, 24 Oct 2024 11:54:12 +0200 Subject: [PATCH 3/3] Update gunpowder_trainer.py to swap to(device) and float() --- dacapo/experiments/trainers/gunpowder_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index 4b5649e77..104c5fa9c 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -309,7 +309,7 @@ def iterate(self, num_iterations, model, optimizer, device): param.grad = None t_start_prediction = time.time() - predicted = model.forward(torch.as_tensor(raw[raw.roi]).to(device).float()) + predicted = model.forward(torch.as_tensor(raw[raw.roi]).float().to(device)) predicted.retain_grad() loss = self._loss.compute( predicted,