Skip to content

Commit

Permalink
🎨 Format Python code with psf/black
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink authored Oct 2, 2024
1 parent 6d7d5d8 commit 6c922da
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
13 changes: 9 additions & 4 deletions dacapo/experiments/architectures/cnnectome_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def module(self):
self.fmaps_out,
[(3,) * len(upsample_factor)] * 2,
activation="ReLU",
batch_norm= self.batch_norm,
batch_norm=self.batch_norm,
)
layers.append(conv)
unet = torch.nn.Sequential(*layers)
Expand Down Expand Up @@ -1307,7 +1307,7 @@ class AttentionBlockModule(nn.Module):
The AttentionBlockModule is an instance of the ``torch.nn.Module`` class.
"""

def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None,batch_norm=True):
def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None, batch_norm=True):
"""
Initialize the Attention Block Module.
Expand Down Expand Up @@ -1337,7 +1337,12 @@ def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None,batch_norm=True):
self.upsample_factor = (2,) * self.dims

self.W_g = ConvPass(
F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None, padding="same",batch_norm=self.batch_norm
F_g,
F_int,
kernel_sizes=self.kernel_sizes,
activation=None,
padding="same",
batch_norm=self.batch_norm,
)

self.W_x = nn.Sequential(
Expand All @@ -1347,7 +1352,7 @@ def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None,batch_norm=True):
kernel_sizes=self.kernel_sizes,
activation=None,
padding="same",
batch_norm=self.batch_norm
batch_norm=self.batch_norm,
),
Downsample(upsample_factor),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def create_from_array_identifier(
"""
Create a new ZarrArray given an array identifier. It is assumed that
this array_identifier points to a dataset that does not yet exist.
Args:
array_identifier (ArrayIdentifier): The array identifier.
axes (List[str]): The axes of the array.
Expand Down
4 changes: 3 additions & 1 deletion dacapo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ def train_run(run: Run, do_validate=True):
weights = weights_store.retrieve_weights(run, iteration=trained_until)

elif latest_weights_iteration > trained_until:
weights = weights_store.retrieve_weights(run, iteration=latest_weights_iteration)
weights = weights_store.retrieve_weights(
run, iteration=latest_weights_iteration
)
logger.error(
f"Found weights for iteration {latest_weights_iteration}, but "
f"run {run.name} was only trained until {trained_until}. "
Expand Down

0 comments on commit 6c922da

Please sign in to comment.