diff --git a/lensless/recon/trainable_recon.py b/lensless/recon/trainable_recon.py index 2e764631..ff805426 100644 --- a/lensless/recon/trainable_recon.py +++ b/lensless/recon/trainable_recon.py @@ -51,6 +51,7 @@ def __init__( n_iter=1, pre_process=None, post_process=None, + background_network=None, skip_unrolled=False, skip_pre=False, skip_post=False, @@ -106,6 +107,12 @@ def __init__( self.skip_pre = skip_pre self.skip_post = skip_post self.direct_background_subtraction = direct_background_subtraction + if background_network is not None: + self.learned_background_subtraction = True + assert ( + direct_background_subtraction is False + ), "Cannot use direct_background_subtraction and background_network at the same time." + self.background_network = background_network self.return_intermediate = return_intermediate self.compensation_branch = compensation if compensation is not None: @@ -244,6 +251,14 @@ def forward(self, batch, psfs=None, background=None): background is not None ), "If direct_background_subtraction is True, background must be defined." self._data = self._data - background + elif self.learned_background_subtraction: + assert ( + background is not None + ), "If project_background is True, background must be defined." + assert ( + self.background_network is not None + ), "If project_background is True, background_network must be defined." + self._data = self._data - self.background_network(background) if psfs is not None: # assert same shape