Skip to content

Commit

Permalink
Add background subtraction network creation.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Aug 14, 2024
1 parent 6f76c05 commit bb3488e
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 3 deletions.
3 changes: 2 additions & 1 deletion configs/train_unrolledADMM.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ reconstruction:
init_post: True # if `init_processors`, set if post-processor is available

# background subtraction (if dataset has corresponding background images)
direct_background_subtraction: False
direct_background_subtraction: False # True or False
learned_background_subtraction: False # False, or set number of channels for UnetRes, e.g. [8,16,32,64]

# Hyperparameters for each method
unrolled_fista: # for unrolled_fista
Expand Down
13 changes: 11 additions & 2 deletions lensless/recon/trainable_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __init__(
assert (
direct_background_subtraction is False
), "Cannot use direct_background_subtraction and background_network at the same time."
self.background_network = background_network
self.set_background_network(background_network)
else:
self.learned_background_subtraction = False
self.background_network = None
Expand Down Expand Up @@ -188,6 +188,13 @@ def set_post_process(self, post_process):
self.post_process_param,
) = self._prepare_process_block(post_process)

def set_background_network(self, background_network):
    """
    Set the learned background-subtraction network.

    Runs the given network through ``_prepare_process_block`` and stores
    the resulting callable, underlying model, and trainable parameter on
    the instance (mirroring how the pre/post process blocks are set).

    Parameters
    ----------
    background_network : :py:class:`torch.nn.Module` or None
        Network that maps a measured background image to the component to
        subtract from the data.
    """
    prepared = self._prepare_process_block(background_network)
    self.background_network = prepared[0]
    self.background_network_model = prepared[1]
    self.background_network_param = prepared[2]

def freeze_pre_process(self):
"""
Method for freezing the pre process block.
Expand Down Expand Up @@ -261,7 +268,9 @@ def forward(self, batch, psfs=None, background=None):
assert (
self.background_network is not None
), "If project_background is True, background_network must be defined."
self._data = self._data - self.background_network(background)
self._data = self._data - self.background_network(
background, self.background_network_param
)

if psfs is not None:
# assert same shape
Expand Down
22 changes: 22 additions & 0 deletions scripts/recon/train_learning_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,8 +492,23 @@ def train_learned(config):
if name1 in dict_params2_post:
dict_params2_post[name1].data.copy_(param1.data)

# check/prepare background subtraction
background_network = None
if config.reconstruction.direct_background_subtraction:
assert test_set.measured_bg and train_set.measured_bg
assert config.reconstruction.learned_background_subtraction is None
if config.reconstruction.learned_background_subtraction is not None:
assert config.reconstruction.direct_background_subtraction is False
assert test_set.measured_bg and train_set.measured_bg

# create UnetRes for background subtraction
background_network, background_network_name = create_process_network(
network="UnetRes",
depth=len(config.reconstruction.learned_background_subtraction),
nc=config.reconstruction.learned_background_subtraction,
device=device,
device_ids=device_ids,
)

# create reconstruction algorithm
if config.reconstruction.init is not None:
Expand Down Expand Up @@ -523,6 +538,7 @@ def train_learned(config):
learn_tk=config.reconstruction.unrolled_fista.learn_tk,
pre_process=pre_process if pre_proc_delay is None else None,
post_process=post_process if post_proc_delay is None else None,
background_network=background_network,
skip_unrolled=config.reconstruction.skip_unrolled,
return_intermediate=(
True if config.unrolled_output_factor > 0 or config.pre_proc_aux > 0 else False
Expand All @@ -541,6 +557,7 @@ def train_learned(config):
tau=config.reconstruction.unrolled_admm.tau,
pre_process=pre_process if pre_proc_delay is None else None,
post_process=post_process if post_proc_delay is None else None,
background_network=background_network,
skip_unrolled=config.reconstruction.skip_unrolled,
return_intermediate=(
True if config.unrolled_output_factor > 0 or config.pre_proc_aux > 0 else False
Expand All @@ -556,6 +573,7 @@ def train_learned(config):
K=config.reconstruction.trainable_inv.K,
pre_process=pre_process if pre_proc_delay is None else None,
post_process=post_process if post_proc_delay is None else None,
background_network=background_network,
return_intermediate=(
True if config.unrolled_output_factor > 0 or config.pre_proc_aux > 0 else False
),
Expand All @@ -570,6 +588,7 @@ def train_learned(config):
psf_channels = 3

assert config.reconstruction.direct_background_subtraction is False, "Not supported"
assert config.reconstruction.learned_background_subtraction is None, "Not supported"

recon = MultiWiener(
in_channels=3,
Expand Down Expand Up @@ -606,6 +625,9 @@ def train_learned(config):
if post_process is not None:
n_param = sum(p.numel() for p in post_process.parameters() if p.requires_grad)
log.info(f"-- Post-process model with {n_param} parameters")
if background_network is not None:
n_param = sum(p.numel() for p in background_network.parameters() if p.requires_grad)
log.info(f"-- Background subtraction model with {n_param} parameters")

log.info(f"Setup time : {time.time() - start_time} s")
log.info(f"PSF shape : {psf.shape}")
Expand Down

0 comments on commit bb3488e

Please sign in to comment.