diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py index d1bcd208c..d89e902ac 100644 --- a/dacapo/experiments/architectures/cnnectome_unet.py +++ b/dacapo/experiments/architectures/cnnectome_unet.py @@ -282,7 +282,7 @@ def module(self): self.fmaps_out, [(3,) * len(upsample_factor)] * 2, activation="ReLU", - batch_norm= self.batch_norm, + batch_norm=self.batch_norm, ) layers.append(conv) unet = torch.nn.Sequential(*layers) @@ -1307,7 +1307,7 @@ class AttentionBlockModule(nn.Module): The AttentionBlockModule is an instance of the ``torch.nn.Module`` class. """ - def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None,batch_norm=True): + def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None, batch_norm=True): """ Initialize the Attention Block Module. @@ -1337,7 +1337,12 @@ def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None,batch_norm=True): self.upsample_factor = (2,) * self.dims self.W_g = ConvPass( - F_g, F_int, kernel_sizes=self.kernel_sizes, activation=None, padding="same",batch_norm=self.batch_norm + F_g, + F_int, + kernel_sizes=self.kernel_sizes, + activation=None, + padding="same", + batch_norm=self.batch_norm, ) self.W_x = nn.Sequential( @@ -1347,7 +1352,7 @@ def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None,batch_norm=True): kernel_sizes=self.kernel_sizes, activation=None, padding="same", - batch_norm=self.batch_norm + batch_norm=self.batch_norm, ), Downsample(upsample_factor), ) diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index 88359c783..f9a26bd09 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -434,7 +434,7 @@ def create_from_array_identifier( """ Create a new ZarrArray given an array identifier. It is assumed that this array_identifier points to a dataset that does not yet exist. - + Args: array_identifier (ArrayIdentifier): The array identifier. axes (List[str]): The axes of the array. diff --git a/dacapo/train.py b/dacapo/train.py index bfd06eeb1..eb28a3cf7 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -113,7 +113,9 @@ def train_run(run: Run, do_validate=True): weights = weights_store.retrieve_weights(run, iteration=trained_until) elif latest_weights_iteration > trained_until: - weights = weights_store.retrieve_weights(run, iteration=latest_weights_iteration) + weights = weights_store.retrieve_weights( + run, iteration=latest_weights_iteration + ) logger.error( f"Found weights for iteration {latest_weights_iteration}, but " f"run {run.name} was only trained until {trained_until}. "