Merge pull request #34 from janelia-cellmap/actions/black
Format Python code with psf/black
mzouink authored Feb 9, 2024
2 parents e1806fe + a9764c8 commit 9d2df1a
Showing 10 changed files with 140 additions and 92 deletions.
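Since this commit only applies psf/black formatting, a quick sketch of what such a pass does may help when reading the hunks below. This is an illustration written for this page, not part of the repository: it feeds one of the long calls touched in this diff through black's format_str API. The line length of 79 is chosen only so the wrapping is visible at module level; at the call's original indentation inside the class, black's usual limit triggers the same split.

import black

# One of the long calls from run.py in this diff, shown at module level.
src = "starter_config = start_config_store.retrieve_run_config(run_config.start_config.run)\n"

# Wraps the call across lines, normalizes quotes to double quotes, and fixes
# spacing around operators and commas; that accounts for essentially every
# change in this commit.
print(black.format_str(src, mode=black.Mode(line_length=79)))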
128 changes: 70 additions & 58 deletions dacapo/experiments/architectures/cnnectome_unet.py
@@ -320,31 +320,29 @@ def __init__(
for _ in range(num_heads)
]
)
        # if num_fmaps_out is None or level != self.num_levels-1 else num_fmaps_out
        if self.use_attention:
            self.attention = nn.ModuleList(
                [
                    nn.ModuleList(
                        [
                            AttentionBlockModule(
                                F_g=num_fmaps * fmap_inc_factor ** (level + 1),
                                F_l=num_fmaps * fmap_inc_factor**level,
                                F_int=num_fmaps
                                * fmap_inc_factor
                                ** (level + (1 - upsample_channel_contraction[level]))
                                if num_fmaps_out is None or level != 0
                                else num_fmaps_out,
                                dims=self.dims,
                                upsample_factor=downsample_factors[level],
                            )
                            for level in range(self.num_levels - 1)
                        ]
                    )
                    for _ in range(num_heads)
                ]
            )

# right convolutional passes
self.r_conv = nn.ModuleList(
@@ -389,12 +387,15 @@ def rec_forward(self, level, f_in):
gs_out = self.rec_forward(level - 1, g_in)

if self.use_attention:
            f_left_attented = [
                self.attention[h][i](gs_out[h], f_left)
                for h in range(self.num_heads)
            ]
fs_right = [
self.r_up[h][i](gs_out[h], f_left_attented[h])
for h in range(self.num_heads)
]
        else:  # up, concat, and crop
fs_right = [
self.r_up[h][i](gs_out[h], f_left) for h in range(self.num_heads)
]
@@ -617,68 +618,80 @@ def forward(self, g_out, f_left=None):
return g_cropped



class AttentionBlockModule(nn.Module):
def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None):
"""Attention Block Module::
        The attention block takes two inputs: 'g' (gating signal) and 'x' (input features).

            [g] --> W_g --\                /--> psi --> * --> [output]
                           \              /
            [x] --> W_x --> [+] --> relu --

        Where:
            - W_g and W_x are 1x1 Convolution followed by Batch Normalization
            - [+] indicates element-wise addition
            - relu is the Rectified Linear Unit activation function
            - psi is a sequence of 1x1 Convolution, Batch Normalization, and Sigmoid activation
            - * indicates element-wise multiplication between the output of psi and input feature 'x'
            - [output] has the same dimensions as input 'x', selectively emphasized by attention weights

        Args:
            F_g (int): The number of feature channels in the gating signal (g).
                This is the input channel dimension for the W_g convolutional layer.

            F_l (int): The number of feature channels in the input features (x).
                This is the input channel dimension for the W_x convolutional layer.

            F_int (int): The number of intermediate feature channels.
                This represents the output channel dimension of the W_g and W_x convolutional layers
                and the input channel dimension for the psi layer. Typically, F_int is smaller
                than F_g and F_l, as it serves to compress the feature representations before
                applying the attention mechanism.

        The AttentionBlock uses two separate pathways to process 'g' and 'x', combines them,
        and applies a sigmoid activation to generate an attention map. This map is then used
        to scale the input features 'x', resulting in an output that focuses on important
        features as dictated by the gating signal 'g'.
        """

super(AttentionBlockModule, self).__init__()
self.dims = dims
self.kernel_sizes = [(1,) * self.dims, (1,) * self.dims]
if upsample_factor is not None:
self.upsample_factor = upsample_factor
else:
            self.upsample_factor = (2,) * self.dims

self.W_g = ConvPass(
            F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None, padding="same"
        )

self.W_x = nn.Sequential(
            ConvPass(
                F_l,
                F_int,
                kernel_sizes=self.kernel_sizes,
                activation=None,
                padding="same",
            ),
            Downsample(upsample_factor),
)

self.psi = ConvPass(
            F_int,
            1,
            kernel_sizes=self.kernel_sizes,
            activation="Sigmoid",
            padding="same",
        )

        up_mode = {2: "bilinear", 3: "trilinear"}[self.dims]

        self.up = nn.Upsample(
            scale_factor=upsample_factor, mode=up_mode, align_corners=True
        )

self.relu = nn.ReLU(inplace=True)

@@ -702,8 +715,7 @@ def calculate_and_apply_padding(self, smaller_tensor, larger_tensor):
padding = padding[::-1]

# Apply symmetric padding
        return nn.functional.pad(smaller_tensor, padding, mode="constant", value=0)

def forward(self, g, x):
g1 = self.W_g(g)
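The AttentionBlockModule changes above are formatting-only, but the docstring's gating scheme is easier to follow with a runnable toy version. The sketch below is written for this page and is not dacapo code: it uses plain torch Conv3d/BatchNorm3d in place of dacapo's ConvPass and assumes g and x already share spatial size, whereas the real module also downsamples x and upsamples the attention map between U-Net levels.

import torch
import torch.nn as nn


class TinyAttentionGate(nn.Module):
    # psi(relu(W_g(g) + W_x(x))) scales x, as in the docstring diagram.
    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        # W_g, W_x: 1x1x1 convolution followed by batch norm
        self.W_g = nn.Sequential(nn.Conv3d(F_g, F_int, 1), nn.BatchNorm3d(F_int))
        self.W_x = nn.Sequential(nn.Conv3d(F_l, F_int, 1), nn.BatchNorm3d(F_int))
        # psi: 1x1x1 convolution to one channel, then sigmoid -> attention map in [0, 1]
        self.psi = nn.Sequential(nn.Conv3d(F_int, 1, 1), nn.BatchNorm3d(1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        # g: gating signal from the coarser level, x: skip-connection features
        alpha = self.psi(self.relu(self.W_g(g) + self.W_x(x)))
        return x * alpha  # emphasize x where the attention map is high


gate = TinyAttentionGate(F_g=32, F_l=16, F_int=8)
g = torch.randn(1, 32, 8, 8, 8)
x = torch.randn(1, 16, 8, 8, 8)
out = gate(g, x)  # same shape as x: (1, 16, 8, 8, 8)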
@@ -9,6 +9,7 @@

logger = logging.getLogger(__file__)


class ConcatArray(Array):
"""This is a wrapper around other `source_arrays` that concatenates
them along the channel dimension."""
@@ -119,7 +120,7 @@ def __getitem__(self, roi: Roi) -> np.ndarray:
axis=0,
)
if concatenated.shape[0] == 1:
            logger.info(
f"Concatenated array has only one channel: {self.name} {concatenated.shape}"
)
return concatenated
21 changes: 14 additions & 7 deletions dacapo/experiments/run.py
@@ -11,6 +11,7 @@

logger = logging.getLogger(__file__)


class Run:
name: str
train_until: int
@@ -58,28 +59,34 @@ def __init__(self, run_config):
return
try:
from ..store import create_config_store

start_config_store = create_config_store()
            starter_config = start_config_store.retrieve_run_config(
                run_config.start_config.run
            )
except Exception as e:
logger.error(f"could not load start config: {e} Should be added to the database config store RUN")
logger.error(
f"could not load start config: {e} Should be added to the database config store RUN"
)
raise e

# preloaded weights from previous run
if run_config.task_config.name == starter_config.task_config.name:
self.start = Start(run_config.start_config)
else:
# Match labels between old and new head
            if hasattr(run_config.task_config, "channels"):
# Map old head and new head
old_head = starter_config.task_config.channels
new_head = run_config.task_config.channels
                self.start = Start(
                    run_config.start_config, old_head=old_head, new_head=new_head
                )
else:
logger.warning("Not implemented channel match for this task")
                self.start = Start(run_config.start_config, remove_head=True)
self.start.initialize_weights(self.model)


@staticmethod
def get_validation_scores(run_config) -> ValidationScores:
"""
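To make the branching in Run.__init__ above easier to scan: a starter run with the same task name reuses its weights directly, a different task that exposes channels goes through head matching, and anything else drops the prediction head. The stub class and task names below are made up for illustration; they only exercise the same name and hasattr checks as the code above.

class TaskStub:
    def __init__(self, name, channels=None):
        self.name = name
        if channels is not None:
            self.channels = channels


starter_task = TaskStub("distances_8nm", channels=["mito", "nucleus"])

for new_task in (
    TaskStub("distances_8nm", channels=["mito", "nucleus"]),  # same task
    TaskStub("distances_4nm", channels=["mito", "ld"]),  # new labels, has channels
    TaskStub("affinities"),  # no channels attribute
):
    if new_task.name == starter_task.name:
        action = "load starter weights as-is"
    elif hasattr(new_task, "channels"):
        action = "match old and new head channels, then load"
    else:
        action = "remove the prediction head, load the rest non-strictly"
    print(new_task.name, "->", action)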
38 changes: 25 additions & 13 deletions dacapo/experiments/starts/start.py
@@ -3,15 +3,21 @@

logger = logging.getLogger(__file__)

# self.old_head =["ecs","plasma_membrane","mito","mito_membrane","vesicle","vesicle_membrane","mvb","mvb_membrane","er","er_membrane","eres","nucleus","microtubules","microtubules_out"]
# self.new_head = ["mito","nucleus","ld","ecs","peroxisome"]
head_keys = [
    "prediction_head.weight",
    "prediction_head.bias",
    "chain.1.weight",
    "chain.1.bias",
]

# Hack
# if label is mito_peroxisome or peroxisome then change it to mito
mitos = ["mito_proxisome","peroxisome"]
mitos = ["mito_proxisome", "peroxisome"]


def match_heads(model, head_weights, old_head, new_head):
# match the heads
for label in new_head:
old_label = label
@@ -30,8 +36,9 @@ def match_heads(model, head_weights, old_head, new_head ):
model.state_dict()[key][new_index] = n_val
logger.warning(f"matched head for {label} with {old_label}")


class Start(ABC):
    def __init__(self, start_config, remove_head=False, old_head=None, new_head=None):
self.run = start_config.run
self.criterion = start_config.criterion
self.remove_head = remove_head
@@ -44,7 +51,9 @@ def initialize_weights(self, model):
weights_store = create_weights_store()
weights = weights_store._retrieve_weights(self.run, self.criterion)

logger.warning(f"loading weights from run {self.run}, criterion: {self.criterion}")
logger.warning(
f"loading weights from run {self.run}, criterion: {self.criterion}"
)

try:
if self.old_head and self.new_head:
@@ -61,15 +70,21 @@ def initialize_weights(self, model):
logger.warning(f"ERROR starter: {e}")

def load_model_using_head_removal(self, model, weights):
logger.warning(f"removing head from run {self.run}, criterion: {self.criterion}")
logger.warning(
f"removing head from run {self.run}, criterion: {self.criterion}"
)
for key in head_keys:
weights.model.pop(key, None)
logger.warning(f"removed head from run {self.run}, criterion: {self.criterion}")
model.load_state_dict(weights.model, strict=False)
logger.warning(f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}")
logger.warning(
f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}"
)

def load_model_using_head_matching(self, model, weights):
logger.warning(f"matching heads from run {self.run}, criterion: {self.criterion}")
logger.warning(
f"matching heads from run {self.run}, criterion: {self.criterion}"
)
logger.warning(f"old head: {self.old_head}")
logger.warning(f"new head: {self.new_head}")
head_weights = {}
@@ -79,6 +94,3 @@ def load_model_using_head_matching(self, model, weights):
weights.model.pop(key, None)
model.load_state_dict(weights.model, strict=False)
model = match_heads(model, head_weights, self.old_head, self.new_head)
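The head-matching path in start.py copies rows of the prediction-head weights for labels that exist in both the old and the new head, and leaves the remaining rows at their fresh initialization. Below is a minimal stand-alone sketch of that idea, using a plain tensor in place of model.state_dict(); the labels and sizes are illustrative, not dacapo's.

import torch

old_head = ["ecs", "plasma_membrane", "mito", "nucleus"]
new_head = ["mito", "nucleus", "ld", "ecs"]

old_weights = torch.randn(len(old_head), 12)  # rows of the old prediction_head.weight
new_weights = torch.randn(len(new_head), 12)  # freshly initialized new head

for new_index, label in enumerate(new_head):
    if label in old_head:
        old_index = old_head.index(label)
        # mirrors match_heads writing into model.state_dict()[key][new_index]
        new_weights[new_index] = old_weights[old_index]

# "ld" was not in the old head, so it keeps its fresh row. The real match_heads
# additionally remaps the labels listed in `mitos` (e.g. "peroxisome") to "mito" first.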



10 changes: 8 additions & 2 deletions dacapo/experiments/tasks/predictors/distance_predictor.py
@@ -27,7 +27,13 @@ class DistancePredictor(Predictor):
in the channels argument.
"""

    def __init__(
        self,
        channels: List[str],
        scale_factor: float,
        mask_distances: bool,
        extra_conv: bool,
    ):
self.channels = channels
self.norm = "tanh"
self.dt_scale_factor = scale_factor
@@ -37,7 +43,7 @@ def __init__(self, channels: List[str], scale_factor: float, mask_distances: boo
self.epsilon = 5e-2
self.threshold = 0.8
self.extra_conv = extra_conv
        self.extra_conv_dims = len(self.channels) * 2

@property
def embedding_dims(self):
