diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml index 1c9b6469..1e8fb354 100644 --- a/configs/train_unrolledADMM.yaml +++ b/configs/train_unrolledADMM.yaml @@ -91,7 +91,8 @@ reconstruction: init_post: True # if `init_processors`, set post-procesor is available # background subtraction (if dataset has corresponding background images) - direct_background_subtraction: False + direct_background_subtraction: False # True or False + learned_background_subtraction: False # False, or list of channels per UnetRes level, e.g. [8,16,32,64] (list length sets the depth) # Hyperparameters for each method unrolled_fista: # for unrolled_fista diff --git a/lensless/recon/trainable_recon.py b/lensless/recon/trainable_recon.py index 96301eec..6c288e87 100644 --- a/lensless/recon/trainable_recon.py +++ b/lensless/recon/trainable_recon.py @@ -112,7 +112,7 @@ def __init__( assert ( direct_background_subtraction is False ), "Cannot use direct_background_subtraction and background_network at the same time." - self.background_network = background_network + self.set_background_network(background_network) else: self.learned_background_subtraction = False self.background_network = None @@ -188,6 +188,13 @@ def set_post_process(self, post_process): self.post_process_param, ) = self._prepare_process_block(post_process) + def set_background_network(self, background_network): + ( + self.background_network, + self.background_network_model, + self.background_network_param, + ) = self._prepare_process_block(background_network) + def freeze_pre_process(self): """ Method for freezing the pre process block. @@ -261,7 +268,9 @@ def forward(self, batch, psfs=None, background=None): assert ( self.background_network is not None ), "If project_background is True, background_network must be defined." 
- self._data = self._data - self.background_network(background) + self._data = self._data - self.background_network( + background, self.background_network_param + ) if psfs is not None: # assert same shape diff --git a/scripts/recon/train_learning_based.py b/scripts/recon/train_learning_based.py index 09547d47..9727452d 100644 --- a/scripts/recon/train_learning_based.py +++ b/scripts/recon/train_learning_based.py @@ -492,8 +492,23 @@ def train_learned(config): if name1 in dict_params2_post: dict_params2_post[name1].data.copy_(param1.data) + # check/prepare background subtraction (learned option is disabled when False, None or empty) + background_network = None if config.reconstruction.direct_background_subtraction: assert test_set.measured_bg and train_set.measured_bg + assert not config.reconstruction.learned_background_subtraction + if config.reconstruction.learned_background_subtraction: + assert config.reconstruction.direct_background_subtraction is False + assert test_set.measured_bg and train_set.measured_bg + + # create UnetRes for background subtraction (depth = number of channel entries) + background_network, background_network_name = create_process_network( + network="UnetRes", + depth=len(config.reconstruction.learned_background_subtraction), + nc=config.reconstruction.learned_background_subtraction, + device=device, + device_ids=device_ids, + ) # create reconstruction algorithm if config.reconstruction.init is not None: @@ -523,6 +538,7 @@ def train_learned(config): learn_tk=config.reconstruction.unrolled_fista.learn_tk, pre_process=pre_process if pre_proc_delay is None else None, post_process=post_process if post_proc_delay is None else None, + background_network=background_network, skip_unrolled=config.reconstruction.skip_unrolled, return_intermediate=( True if config.unrolled_output_factor > 0 or config.pre_proc_aux > 0 else False @@ -541,6 +557,7 @@ def train_learned(config): tau=config.reconstruction.unrolled_admm.tau, pre_process=pre_process if pre_proc_delay is None else None, post_process=post_process if 
post_proc_delay is None else None, + background_network=background_network, skip_unrolled=config.reconstruction.skip_unrolled, return_intermediate=( True if config.unrolled_output_factor > 0 or config.pre_proc_aux > 0 else False @@ -556,6 +573,7 @@ def train_learned(config): K=config.reconstruction.trainable_inv.K, pre_process=pre_process if pre_proc_delay is None else None, post_process=post_process if post_proc_delay is None else None, + background_network=background_network, return_intermediate=( True if config.unrolled_output_factor > 0 or config.pre_proc_aux > 0 else False ), @@ -570,6 +588,7 @@ def train_learned(config): psf_channels = 3 assert config.reconstruction.direct_background_subtraction is False, "Not supported" + assert not config.reconstruction.learned_background_subtraction, "Not supported" recon = MultiWiener( in_channels=3, @@ -606,6 +625,9 @@ def train_learned(config): if post_process is not None: n_param = sum(p.numel() for p in post_process.parameters() if p.requires_grad) log.info(f"-- Post-process model with {n_param} parameters") + if background_network is not None: + n_param = sum(p.numel() for p in background_network.parameters() if p.requires_grad) + log.info(f"-- Background subtraction model with {n_param} parameters") log.info(f"Setup time : {time.time() - start_time} s") log.info(f"PSF shape : {psf.shape}")