-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
176 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# ############################################################################# | ||
# trainable_inversion.py | ||
# ================= | ||
# Authors : | ||
# Eric BEZZAM [[email protected]] | ||
# ############################################################################# | ||
|
||
from lensless.recon.trainable_recon import TrainableReconstructionAlgorithm | ||
|
||
|
||
class TrainableInversion(TrainableReconstructionAlgorithm): | ||
""" """ | ||
|
||
def __init__(self, psf, dtype=None, K=1e-4, **kwargs): | ||
""" | ||
Constructor for trainable inversion component as proposed in | ||
the FlatNet work: https://siddiquesalman.github.io/flatnet/ | ||
Parameters | ||
---------- | ||
psf : :py:class:`~torch.Tensor` | ||
Point spread function (PSF) that models forward propagation. | ||
Must be of shape (depth, height, width, channels) even if | ||
depth = 1 and channels = 1. You can use :py:func:`~lensless.io.load_psf` | ||
to load a PSF from a file such that it is in the correct format. | ||
dtype : float32 or float64 | ||
Data type to use for optimization. | ||
K : float | ||
Regularization parameter. | ||
""" | ||
|
||
super(TrainableInversion, self).__init__(psf, n_iter=1, dtype=dtype, reset=False, **kwargs) | ||
self._convolver._Hadj = self._convolver._Hadj / (self._convolver._H.norm() ** 2 + K) | ||
|
||
self.reset() | ||
|
||
def _form_image(self): | ||
self._image_est[self._image_est < 0] = 0 | ||
return self._image_est | ||
|
||
def _set_psf(self, psf): | ||
return super()._set_psf(psf) | ||
|
||
def reset(self, batch_size=1): | ||
# no state variables | ||
return | ||
|
||
def _update(self, iter): | ||
self._image_est = self._convolver.deconvolve(self._data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters