Skip to content

Commit

Permalink
Add flag for learned background subtraction.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Aug 14, 2024
1 parent 31637ea commit 2231fd7
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions lensless/recon/trainable_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
n_iter=1,
pre_process=None,
post_process=None,
background_network=None,
skip_unrolled=False,
skip_pre=False,
skip_post=False,
Expand Down Expand Up @@ -106,6 +107,12 @@ def __init__(
self.skip_pre = skip_pre
self.skip_post = skip_post
self.direct_background_subtraction = direct_background_subtraction
if background_network is not None:
self.learned_background_subtraction = True
assert (
direct_background_subtraction is False
), "Cannot use direct_background_subtraction and background_network at the same time."
self.background_network = background_network
self.return_intermediate = return_intermediate
self.compensation_branch = compensation
if compensation is not None:
Expand Down Expand Up @@ -244,6 +251,14 @@ def forward(self, batch, psfs=None, background=None):
background is not None
), "If direct_background_subtraction is True, background must be defined."
self._data = self._data - background
elif self.learned_background_subtraction:
assert (
background is not None
), "If project_background is True, background must be defined."
assert (
self.background_network is not None
), "If project_background is True, background_network must be defined."
self._data = self._data - self.background_network(background)

if psfs is not None:
# assert same shape
Expand Down

0 comments on commit 2231fd7

Please sign in to comment.