From 5999d7cf3218b58ec24d04379c87c00feff24b82 Mon Sep 17 00:00:00 2001 From: noakraicer Date: Thu, 29 Aug 2024 16:19:14 +0300 Subject: [PATCH] final changes --- configs/benchmark_hyperspectral.yaml | 2 +- lensless/recon/gd.py | 6 +++--- lensless/recon/recon.py | 3 +++ 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/configs/benchmark_hyperspectral.yaml b/configs/benchmark_hyperspectral.yaml index 5f775875..53671cf4 100644 --- a/configs/benchmark_hyperspectral.yaml +++ b/configs/benchmark_hyperspectral.yaml @@ -32,7 +32,7 @@ huggingface: device: "cuda" # numbers of iterations to benchmark -n_iter_range: [5, 10, 20, 50, 100, 200, 300] +n_iter_range: [2000] # number of files to benchmark n_files: null # null for all files #How much should the image be downsampled diff --git a/lensless/recon/gd.py b/lensless/recon/gd.py index 0b6946d8..c5af193c 100644 --- a/lensless/recon/gd.py +++ b/lensless/recon/gd.py @@ -101,12 +101,12 @@ def reset(self): # torch.max(psf_flat, axis=0).values + torch.min(psf_flat, axis=0).values # ) / 2 # initialize image estimate as [Batch, Depth, Height, Width, Channels] - self._image_est = torch.zeros((1,250,250,3)) + self._image_est = torch.zeros((1,250,250,3)).to(self._psf.device) # set step size as < 2 / lipschitz Hadj_flat = self._convolver._Hadj.reshape(-1, self._psf_shape[3]) H_flat = self._convolver._H.reshape(-1, self._psf_shape[3]) - self._alpha = torch.real(1.8 / torch.max(torch.abs(Hadj_flat * H_flat), axis=0).values) + self._alpha = 1/4770.13 else: if self._initial_est is not None: @@ -123,7 +123,7 @@ def reset(self): self._alpha = np.real(1.8 / np.max(Hadj_flat * H_flat, axis=0)) def _grad(self): - diff = np.sum(self.mask * self._convolver.convolve(self._image_est), -1) - self._data # (H, W, 1) + diff = torch.sum(self.mask * self._convolver.convolve(self._image_est), axis=-1, keepdims=True) - self._data # (H, W, 1) return self._convolver.deconvolve(diff * self.mask) # (H, W, C) where C is number of hyperspectral channels def _update(self, iter): diff --git a/lensless/recon/recon.py b/lensless/recon/recon.py index 7de40520..b17b165f 100644 --- a/lensless/recon/recon.py +++ b/lensless/recon/recon.py @@ -571,6 +571,9 @@ def apply( for i in range(n_iter): self._update(i) + if i%50==0: + img = self._form_image() + if self.compensation_branch is not None and i < self._n_iter - 1: self.compensation_branch_inputs.append(self._form_image())