Skip to content

Commit

Permalink
final changes
Browse files Browse the repository at this point in the history
  • Loading branch information
noakraicer committed Aug 29, 2024
1 parent b09e498 commit 5999d7c
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion configs/benchmark_hyperspectral.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ huggingface:

device: "cuda"
# numbers of iterations to benchmark
n_iter_range: [5, 10, 20, 50, 100, 200, 300]
n_iter_range: [2000]
# number of files to benchmark
n_files: null # null for all files
#How much should the image be downsampled
Expand Down
6 changes: 3 additions & 3 deletions lensless/recon/gd.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ def reset(self):
# torch.max(psf_flat, axis=0).values + torch.min(psf_flat, axis=0).values
# ) / 2
# initialize image estimate as [Batch, Depth, Height, Width, Channels]
self._image_est = torch.zeros((1,250,250,3))
self._image_est = torch.zeros((1,250,250,3)).to(self._psf.device)

# set step size as < 2 / lipschitz
Hadj_flat = self._convolver._Hadj.reshape(-1, self._psf_shape[3])
H_flat = self._convolver._H.reshape(-1, self._psf_shape[3])
self._alpha = torch.real(1.8 / torch.max(torch.abs(Hadj_flat * H_flat), axis=0).values)
self._alpha = 1/4770.13

else:
if self._initial_est is not None:
Expand All @@ -123,7 +123,7 @@ def reset(self):
self._alpha = np.real(1.8 / np.max(Hadj_flat * H_flat, axis=0))

def _grad(self):
diff = np.sum(self.mask * self._convolver.convolve(self._image_est), -1) - self._data # (H, W, 1)
diff = torch.sum(self.mask * self._convolver.convolve(self._image_est), axis=-1, keepdims=True) - self._data # (H, W, 1)
return self._convolver.deconvolve(diff * self.mask) # (H, W, C) where C is number of hyperspectral channels

def _update(self, iter):
Expand Down
3 changes: 3 additions & 0 deletions lensless/recon/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,9 @@ def apply(

for i in range(n_iter):
self._update(i)
if i%50==0:
img = self._form_image()

if self.compensation_branch is not None and i < self._n_iter - 1:
self.compensation_branch_inputs.append(self._form_image())

Expand Down

0 comments on commit 5999d7c

Please sign in to comment.