From b36c1e58de97aac43eddc5fd38da9c6d1770d064 Mon Sep 17 00:00:00 2001
From: rhoadesScholar
Date: Tue, 13 Feb 2024 21:41:46 -0500
Subject: [PATCH] =?UTF-8?q?style:=20=F0=9F=8E=A8=20Black=20formatted.?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 care_train.py                                 | 17 ++++----
 .../architectures/cnnectome_unet.py           | 16 ++++---
 .../architectures/nlayer_discriminator.py     | 42 ++++++++++++++-----
 .../datasets/arrays/concat_array.py           |  8 ++--
 .../datasets/arrays/intensity_array_config.py |  1 -
 dacapo/experiments/tasks/CARE_task.py         |  7 +++-
 dacapo/experiments/tasks/CARE_task_config.py  |  7 ++--
 .../experiments/tasks/CycleGAN_task_config.py |  2 +-
 dacapo/experiments/tasks/Pix2Pix_task.py      |  7 +++-
 .../experiments/tasks/Pix2Pix_task_config.py  |  7 ++--
 dacapo/experiments/tasks/__init__.py          |  4 +-
 .../experiments/tasks/evaluators/__init__.py  |  4 +-
 .../tasks/evaluators/evaluation_scores.py     |  1 -
 .../intensities_evaluation_scores.py          |  3 +-
 .../tasks/evaluators/intensities_evaluator.py | 22 ++++++----
 dacapo/experiments/tasks/losses/GANLoss.py    | 18 ++++----
 dacapo/experiments/tasks/losses/__init__.py   |  2 +-
 .../post_processors/CARE_post_processor.py    |  7 ++--
 .../CycleGAN_post_processor.py                |  7 ++--
 .../tasks/post_processors/__init__.py         |  6 +--
 .../tasks/predictors/CARE_predictor.py        |  8 ++--
 .../tasks/predictors/CycleGANPredictor.py     |  4 +-
 .../experiments/tasks/predictors/__init__.py  |  2 +-
 .../experiments/trainers/gunpowder_trainer.py | 20 +++++----
 dacapo/plot.py                                |  8 ++--
 docs/source/conf.py                           | 23 +++++-----
 26 files changed, 150 insertions(+), 103 deletions(-)

diff --git a/care_train.py b/care_train.py
index 606ba91b2..399bc0dd0 100644
--- a/care_train.py
+++ b/care_train.py
@@ -5,7 +5,10 @@
 from torchsummary import summary
 
 # CARE task specific elements
-from dacapo.experiments.datasplits.datasets.arrays import ZarrArrayConfig, IntensitiesArrayConfig
+from dacapo.experiments.datasplits.datasets.arrays import (
+    ZarrArrayConfig,
+    IntensitiesArrayConfig,
+)
 from dacapo.experiments.datasplits.datasets import RawGTDatasetConfig
 from dacapo.experiments.datasplits import TrainValidateDataSplitConfig
 from dacapo.experiments.architectures import CNNectomeUNetConfig
@@ -40,17 +43,11 @@
 )
 
 raw_array_config_int = IntensitiesArrayConfig(
-    name="raw_norm",
-    source_array_config = raw_array_config_zarr,
-    min = 0.,
-    max = 1.
+    name="raw_norm", source_array_config=raw_array_config_zarr, min=0.0, max=1.0
 )
 
 gt_array_config_int = IntensitiesArrayConfig(
-    name="gt_norm",
-    source_array_config = gt_array_config_zarr,
-    min = 0.,
-    max = 1.
+ name="gt_norm", source_array_config=gt_array_config_zarr, min=0.0, max=1.0 ) dataset_config = RawGTDatasetConfig( @@ -152,4 +149,4 @@ """ RuntimeError: Can not downsample shape torch.Size([1, 128, 47, 47, 47]) with factor (2, 2, 2), mismatch in spatial dimension 2 -""" \ No newline at end of file +""" diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py index 01a261d09..338ac6ea1 100644 --- a/dacapo/experiments/architectures/cnnectome_unet.py +++ b/dacapo/experiments/architectures/cnnectome_unet.py @@ -273,9 +273,11 @@ def __init__( self.l_conv = nn.ModuleList( [ ConvPass( - in_channels - if level == 0 - else num_fmaps * fmap_inc_factor ** (level - 1), + ( + in_channels + if level == 0 + else num_fmaps * fmap_inc_factor ** (level - 1) + ), num_fmaps * fmap_inc_factor**level, kernel_size_down[level], activation=activation, @@ -327,9 +329,11 @@ def __init__( + num_fmaps * fmap_inc_factor ** (level + (1 - upsample_channel_contraction[level])), - num_fmaps * fmap_inc_factor**level - if num_fmaps_out is None or level != 0 - else num_fmaps_out, + ( + num_fmaps * fmap_inc_factor**level + if num_fmaps_out is None or level != 0 + else num_fmaps_out + ), kernel_size_up[level], activation=activation, padding=padding, diff --git a/dacapo/experiments/architectures/nlayer_discriminator.py b/dacapo/experiments/architectures/nlayer_discriminator.py index 66a5967a3..c203fdbfb 100644 --- a/dacapo/experiments/architectures/nlayer_discriminator.py +++ b/dacapo/experiments/architectures/nlayer_discriminator.py @@ -4,6 +4,7 @@ import torch.nn as nn import functools + class NLayerDiscriminator(Architecture): """Defines a PatchGAN discriminator""" @@ -22,36 +23,57 @@ def __init__(self, architecture_config): n_layers: int = architecture_config.n_layers norm_layer = architecture_config.norm_layer - if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + if ( + type(norm_layer) == functools.partial + ): # no need to use bias as BatchNorm2d has affine parameters use_bias = norm_layer.func == nn.InstanceNorm2d else: use_bias = norm_layer == nn.InstanceNorm2d kw = 4 padw = 1 - sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + sequence = [ + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True), + ] nf_mult = 1 nf_mult_prev = 1 for n in range(1, n_layers): # gradually increase the number of filters nf_mult_prev = nf_mult - nf_mult = min(2 ** n, 8) + nf_mult = min(2**n, 8) sequence += [ - nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), + nn.Conv2d( + ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=2, + padding=padw, + bias=use_bias, + ), norm_layer(ndf * nf_mult), - nn.LeakyReLU(0.2, True) + nn.LeakyReLU(0.2, True), ] nf_mult_prev = nf_mult - nf_mult = min(2 ** n_layers, 8) + nf_mult = min(2**n_layers, 8) sequence += [ - nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), + nn.Conv2d( + ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=1, + padding=padw, + bias=use_bias, + ), norm_layer(ndf * nf_mult), - nn.LeakyReLU(0.2, True) + nn.LeakyReLU(0.2, True), ] - sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map + sequence += [ + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) + ] # output 1 channel prediction 
map self.model = nn.Sequential(*sequence) def forward(self, input): """Standard forward.""" - return self.model(input) \ No newline at end of file + return self.model(input) diff --git a/dacapo/experiments/datasplits/datasets/arrays/concat_array.py b/dacapo/experiments/datasplits/datasets/arrays/concat_array.py index 122526b14..aceda2e77 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/concat_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/concat_array.py @@ -99,9 +99,11 @@ def __getitem__(self, roi: Roi) -> np.ndarray: else self.default_array[roi] ) arrays = [ - self.source_arrays[channel][roi] - if channel in self.source_arrays - else default + ( + self.source_arrays[channel][roi] + if channel in self.source_arrays + else default + ) for channel in self.channels ] shapes = [array.shape for array in arrays] diff --git a/dacapo/experiments/datasplits/datasets/arrays/intensity_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/intensity_array_config.py index 0faf616c6..87281f69f 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/intensity_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/intensity_array_config.py @@ -19,4 +19,3 @@ class IntensitiesArrayConfig(ArrayConfig): min: float = attr.ib(metadata={"help_text": "The minimum intensity in your data"}) max: float = attr.ib(metadata={"help_text": "The maximum intensity in your data"}) - diff --git a/dacapo/experiments/tasks/CARE_task.py b/dacapo/experiments/tasks/CARE_task.py index 30697b1b7..519cdf701 100644 --- a/dacapo/experiments/tasks/CARE_task.py +++ b/dacapo/experiments/tasks/CARE_task.py @@ -4,12 +4,15 @@ from .predictors import CAREPredictor from .task import Task + class CARETask(Task): """CAREPredictor.""" def __init__(self, task_config) -> None: """Create a `CARETask`.""" - self.predictor = CAREPredictor(num_channels=task_config.num_channels, dims=task_config.dims) + self.predictor = CAREPredictor( + num_channels=task_config.num_channels, dims=task_config.dims + ) self.loss = MSELoss() self.post_processor = CAREPostProcessor() - self.evaluator = IntensitiesEvaluator() \ No newline at end of file + self.evaluator = IntensitiesEvaluator() diff --git a/dacapo/experiments/tasks/CARE_task_config.py b/dacapo/experiments/tasks/CARE_task_config.py index dc6274ac4..fccae9333 100644 --- a/dacapo/experiments/tasks/CARE_task_config.py +++ b/dacapo/experiments/tasks/CARE_task_config.py @@ -16,12 +16,13 @@ class CARETaskConfig(TaskConfig): metadata={ "help_text": "Number of output channels for the image. " "Number of ouptut channels should match the number of channels in the ground truth." - }) - + }, + ) + dims: int = attr.ib( default=2, metadata={ "help_text": "Number of UNet dimensions. " "Number of dimensions should match the number of channels in the ground truth." - } + }, ) diff --git a/dacapo/experiments/tasks/CycleGAN_task_config.py b/dacapo/experiments/tasks/CycleGAN_task_config.py index 493252d6d..f63e90b31 100644 --- a/dacapo/experiments/tasks/CycleGAN_task_config.py +++ b/dacapo/experiments/tasks/CycleGAN_task_config.py @@ -17,5 +17,5 @@ class CycleGANTaskConfig(TaskConfig): metadata={ "help_text": "Number of output channels for the image. " "Number of ouptut channels should match the number of channels in the ground truth." 
-        }
+        },
     )
diff --git a/dacapo/experiments/tasks/Pix2Pix_task.py b/dacapo/experiments/tasks/Pix2Pix_task.py
index 6ba1e2ee3..9a52fc77e 100644
--- a/dacapo/experiments/tasks/Pix2Pix_task.py
+++ b/dacapo/experiments/tasks/Pix2Pix_task.py
@@ -4,12 +4,15 @@
 from .predictors import CAREPredictor
 from .task import Task
 
+
 class Pix2PixTask(Task):
     """Pix2Pix Predictor."""
 
     def __init__(self, task_config) -> None:
         """Create a `Pix2PixTask`."""
-        self.predictor = Pix2Pix_predictor(num_channels=task_config.num_channels, dims=task_config.dims)
+        self.predictor = Pix2Pix_predictor(
+            num_channels=task_config.num_channels, dims=task_config.dims
+        )
         self.loss = MSELoss()  # TODO: change losses
         self.post_processor = CAREPostProcessor()  # TODO: change post processor
-        self.evaluator = IntensitiesEvaluator()
\ No newline at end of file
+        self.evaluator = IntensitiesEvaluator()
diff --git a/dacapo/experiments/tasks/Pix2Pix_task_config.py b/dacapo/experiments/tasks/Pix2Pix_task_config.py
index 4a49b6436..ca5751fad 100644
--- a/dacapo/experiments/tasks/Pix2Pix_task_config.py
+++ b/dacapo/experiments/tasks/Pix2Pix_task_config.py
@@ -16,12 +16,13 @@ class Pix2PixTaskConfig(TaskConfig):
         metadata={
             "help_text": "Number of output channels for the image. "
             "Number of ouptut channels should match the number of channels in the ground truth."
-        })
-
+        },
+    )
+
     dims: int = attr.ib(
         default=2,
         metadata={
             "help_text": "Number of UNet dimensions. "
             "Number of dimensions should match the number of channels in the ground truth."
-        }
+        },
     )
diff --git a/dacapo/experiments/tasks/__init__.py b/dacapo/experiments/tasks/__init__.py
index 05eae210d..65ce71a5a 100644
--- a/dacapo/experiments/tasks/__init__.py
+++ b/dacapo/experiments/tasks/__init__.py
@@ -5,5 +5,5 @@
 from .one_hot_task_config import OneHotTaskConfig, OneHotTask  # noqa
 from .pretrained_task_config import PretrainedTaskConfig, PretrainedTask  # noqa
 from .affinities_task_config import AffinitiesTaskConfig, AffinitiesTask  # noqa
-from .CARE_task_config import CARETaskConfig, CARETask # noqa
-from .CycleGAN_task_config import CycleGANTaskConfig, CycleGANTask # noqa
\ No newline at end of file
+from .CARE_task_config import CARETaskConfig, CARETask  # noqa
+from .CycleGAN_task_config import CycleGANTaskConfig, CycleGANTask  # noqa
diff --git a/dacapo/experiments/tasks/evaluators/__init__.py b/dacapo/experiments/tasks/evaluators/__init__.py
index 9fc295934..2daf37545 100644
--- a/dacapo/experiments/tasks/evaluators/__init__.py
+++ b/dacapo/experiments/tasks/evaluators/__init__.py
@@ -11,5 +11,5 @@
 from .instance_evaluator import InstanceEvaluator  # noqa
 
 
-from .intensities_evaluation_scores import IntensitiesEvaluationScores # noqa
-from .intensities_evaluator import IntensitiesEvaluator # noqa
+from .intensities_evaluation_scores import IntensitiesEvaluationScores  # noqa
+from .intensities_evaluator import IntensitiesEvaluator  # noqa
diff --git a/dacapo/experiments/tasks/evaluators/evaluation_scores.py b/dacapo/experiments/tasks/evaluators/evaluation_scores.py
index 8909a790a..fce810cce 100644
--- a/dacapo/experiments/tasks/evaluators/evaluation_scores.py
+++ b/dacapo/experiments/tasks/evaluators/evaluation_scores.py
@@ -1,4 +1,3 @@
-
 import attr
 
 from abc import abstractmethod
diff --git a/dacapo/experiments/tasks/evaluators/intensities_evaluation_scores.py b/dacapo/experiments/tasks/evaluators/intensities_evaluation_scores.py
index 6e85902d6..60dd56a13 100644
--- a/dacapo/experiments/tasks/evaluators/intensities_evaluation_scores.py
+++ b/dacapo/experiments/tasks/evaluators/intensities_evaluation_scores.py
@@ -3,9 +3,10 @@
 
 from typing import Tuple
 
+
 @attr.s
 class IntensitiesEvaluationScores(EvaluationScores):
-    criteria: property = ['ssim', 'psnr', 'nrmse']
+    criteria: property = ["ssim", "psnr", "nrmse"]
 
     ssim: float = attr.ib(default=float("nan"))
     psnr: float = attr.ib(default=float("nan"))
diff --git a/dacapo/experiments/tasks/evaluators/intensities_evaluator.py b/dacapo/experiments/tasks/evaluators/intensities_evaluator.py
index a18c2cc39..81b8a25d5 100644
--- a/dacapo/experiments/tasks/evaluators/intensities_evaluator.py
+++ b/dacapo/experiments/tasks/evaluators/intensities_evaluator.py
@@ -1,7 +1,11 @@
 import xarray as xr
 from dacapo.experiments.datasplits.datasets.arrays import ZarrArray
 import numpy as np
-from skimage.metrics import structural_similarity, peak_signal_noise_ratio, normalized_root_mse
+from skimage.metrics import (
+    structural_similarity,
+    peak_signal_noise_ratio,
+    normalized_root_mse,
+)
 
 from .evaluator import Evaluator
 from .intensities_evaluation_scores import IntensitiesEvaluationScores
@@ -13,17 +17,21 @@ class IntensitiesEvaluator(Evaluator):
     An evaluator takes a post-processor's output and compares it against
     ground-truth.
     """
+
     criteria = ["ssim", "psnr", "nrmse"]
-
-    def evaluate(self, output_array_identifier, evaluation_array) -> IntensitiesEvaluationScores:
+
+    def evaluate(
+        self, output_array_identifier, evaluation_array
+    ) -> IntensitiesEvaluationScores:
         output_array = ZarrArray.open_from_array_identifier(output_array_identifier)
         evaluation_data = evaluation_array[evaluation_array.roi].astype(np.uint64)
         output_data = output_array[output_array.roi].astype(np.uint64)
-        return IntensitiesEvaluationScores(ssim=structural_similarity(evaluation_data, output_data),
-                                           psnr=peak_signal_noise_ratio(evaluation_data, output_data),
-                                           nrmse=normalized_root_mse(evaluation_data, output_data))
+        return IntensitiesEvaluationScores(
+            ssim=structural_similarity(evaluation_data, output_data),
+            psnr=peak_signal_noise_ratio(evaluation_data, output_data),
+            nrmse=normalized_root_mse(evaluation_data, output_data),
+        )
 
     @property
     def score(self) -> IntensitiesEvaluationScores:
         return IntensitiesEvaluationScores()
-
diff --git a/dacapo/experiments/tasks/losses/GANLoss.py b/dacapo/experiments/tasks/losses/GANLoss.py
index d4a6fe8f7..57bcf41f8 100644
--- a/dacapo/experiments/tasks/losses/GANLoss.py
+++ b/dacapo/experiments/tasks/losses/GANLoss.py
@@ -10,7 +10,7 @@ class GANLoss(Loss):
     """
 
     def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
-        """ Initialize the GANLoss class.
+        """Initialize the GANLoss class.
         Parameters:
             gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
             target_real_label (bool) - - label for a real image
@@ -19,17 +19,17 @@ def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
         LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
""" super(GANLoss, self).__init__() - self.register_buffer('real_label', torch.tensor(target_real_label)) - self.register_buffer('fake_label', torch.tensor(target_fake_label)) + self.register_buffer("real_label", torch.tensor(target_real_label)) + self.register_buffer("fake_label", torch.tensor(target_fake_label)) self.gan_mode = gan_mode - if gan_mode == 'lsgan': + if gan_mode == "lsgan": self.loss = nn.MSELoss() - elif gan_mode == 'vanilla': + elif gan_mode == "vanilla": self.loss = nn.BCEWithLogitsLoss() - elif gan_mode in ['wgangp']: + elif gan_mode in ["wgangp"]: self.loss = None else: - raise NotImplementedError('gan mode %s not implemented' % gan_mode) + raise NotImplementedError("gan mode %s not implemented" % gan_mode) def get_target_tensor(self, prediction, target_is_real): """Create label tensors with the same size as the input. @@ -54,10 +54,10 @@ def __call__(self, prediction, target_is_real): Returns: the calculated loss. """ - if self.gan_mode in ['lsgan', 'vanilla']: + if self.gan_mode in ["lsgan", "vanilla"]: target_tensor = self.get_target_tensor(prediction, target_is_real) loss = self.loss(prediction, target_tensor) - elif self.gan_mode == 'wgangp': + elif self.gan_mode == "wgangp": if target_is_real: loss = -prediction.mean() else: diff --git a/dacapo/experiments/tasks/losses/__init__.py b/dacapo/experiments/tasks/losses/__init__.py index c50f4fa9e..f14fae958 100644 --- a/dacapo/experiments/tasks/losses/__init__.py +++ b/dacapo/experiments/tasks/losses/__init__.py @@ -2,4 +2,4 @@ from .mse_loss import MSELoss # noqa from .loss import Loss # noqa from .affinities_loss import AffinitiesLoss # noqa -from .GANLoss import GANLoss # noqa \ No newline at end of file +from .GANLoss import GANLoss # noqa diff --git a/dacapo/experiments/tasks/post_processors/CARE_post_processor.py b/dacapo/experiments/tasks/post_processors/CARE_post_processor.py index 88729158e..cc0af3742 100644 --- a/dacapo/experiments/tasks/post_processors/CARE_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/CARE_post_processor.py @@ -13,7 +13,6 @@ from dacapo.experiments.tasks.post_processors import PostProcessorParameters - class CAREPostProcessor(PostProcessor): def __init__(self) -> None: super().__init__() @@ -23,7 +22,9 @@ def enumerate_parameters(self) -> Iterable[CAREPostProcessorParameters]: yield CAREPostProcessorParameters(id=1) - def set_prediction(self, prediction_array_identifier: "LocalArrayIdentifier"): # TODO + def set_prediction( + self, prediction_array_identifier: "LocalArrayIdentifier" + ): # TODO self.prediction_array = ZarrArray.open_from_array_identifier( prediction_array_identifier ) @@ -32,7 +33,7 @@ def process( self, parameters: "PostProcessorParameters", output_array_identifier: "LocalArrayIdentifier", - ) -> ZarrArray: + ) -> ZarrArray: output_array: ZarrArray = ZarrArray.create_from_array_identifier( output_array_identifier, diff --git a/dacapo/experiments/tasks/post_processors/CycleGAN_post_processor.py b/dacapo/experiments/tasks/post_processors/CycleGAN_post_processor.py index 735902afb..8928c1330 100644 --- a/dacapo/experiments/tasks/post_processors/CycleGAN_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/CycleGAN_post_processor.py @@ -13,7 +13,6 @@ from dacapo.experiments.tasks.post_processors import PostProcessorParameters - class CycleGANPostProcessor(PostProcessor): def __init__(self) -> None: super().__init__() @@ -23,7 +22,9 @@ def enumerate_parameters(self) -> Iterable[CycleGANPostProcessorParameters]: yield 
CycleGANPostProcessorParameters(id=1) - def set_prediction(self, prediction_array_identifier: "LocalArrayIdentifier"): # TODO + def set_prediction( + self, prediction_array_identifier: "LocalArrayIdentifier" + ): # TODO self.prediction_array = ZarrArray.open_from_array_identifier( prediction_array_identifier ) @@ -32,7 +33,7 @@ def process( self, parameters: "PostProcessorParameters", output_array_identifier: "LocalArrayIdentifier", - ) -> ZarrArray: + ) -> ZarrArray: output_array: ZarrArray = ZarrArray.create_from_array_identifier( output_array_identifier, diff --git a/dacapo/experiments/tasks/post_processors/__init__.py b/dacapo/experiments/tasks/post_processors/__init__.py index f2852641f..e3794594d 100644 --- a/dacapo/experiments/tasks/post_processors/__init__.py +++ b/dacapo/experiments/tasks/post_processors/__init__.py @@ -13,8 +13,8 @@ WatershedPostProcessorParameters, ) # noqa -from .CARE_post_processor import CAREPostProcessor # noqa +from .CARE_post_processor import CAREPostProcessor # noqa from .CARE_post_processor_parameters import CAREPostProcessorParameters # noqa -from .CycleGAN_post_processor import CycleGANPostProcessor # noqa -from .CycleGAN_post_processor_parameters import CycleGANPostProcessorParameters # noqa \ No newline at end of file +from .CycleGAN_post_processor import CycleGANPostProcessor # noqa +from .CycleGAN_post_processor_parameters import CycleGANPostProcessorParameters # noqa diff --git a/dacapo/experiments/tasks/predictors/CARE_predictor.py b/dacapo/experiments/tasks/predictors/CARE_predictor.py index 6906bcd5f..a52fa2ef2 100644 --- a/dacapo/experiments/tasks/predictors/CARE_predictor.py +++ b/dacapo/experiments/tasks/predictors/CARE_predictor.py @@ -47,6 +47,8 @@ def create_weight(self, gt, target=None, mask=None): @property def output_array_type(self): - return IntensitiesArray({"channels": {n: str(n) for n in range(self.num_channels)}}, min=0., max=1.) 
-
-
+        return IntensitiesArray(
+            {"channels": {n: str(n) for n in range(self.num_channels)}},
+            min=0.0,
+            max=1.0,
+        )
diff --git a/dacapo/experiments/tasks/predictors/CycleGANPredictor.py b/dacapo/experiments/tasks/predictors/CycleGANPredictor.py
index 8eaaffbf3..8a75482d4 100644
--- a/dacapo/experiments/tasks/predictors/CycleGANPredictor.py
+++ b/dacapo/experiments/tasks/predictors/CycleGANPredictor.py
@@ -50,7 +50,7 @@ def create_weight(self, gt):
     @property
     def output_array_type(self):
         return ZarrArray(self.num_channels)
-    
+
     def gt_region_for_roi(self, target_spec):
         if self.mask_distances:
             gt_spec = target_spec.copy()
@@ -64,5 +64,3 @@ def gt_region_for_roi(self, target_spec):
 
     def padding(self, gt_voxel_size: Coordinate) -> Coordinate:
         return Coordinate((self.max_distance,) * gt_voxel_size.dims)
-
-
diff --git a/dacapo/experiments/tasks/predictors/__init__.py b/dacapo/experiments/tasks/predictors/__init__.py
index efdd8db6e..040a9b874 100644
--- a/dacapo/experiments/tasks/predictors/__init__.py
+++ b/dacapo/experiments/tasks/predictors/__init__.py
@@ -3,4 +3,4 @@
 from .one_hot_predictor import OneHotPredictor  # noqa
 from .predictor import Predictor  # noqa
 from .affinities_predictor import AffinitiesPredictor  # noqa
-from .CARE_predictor import CAREPredictor # noqa
\ No newline at end of file
+from .CARE_predictor import CAREPredictor  # noqa
diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py
index efec630f0..75f06bbd1 100644
--- a/dacapo/experiments/trainers/gunpowder_trainer.py
+++ b/dacapo/experiments/trainers/gunpowder_trainer.py
@@ -132,12 +132,12 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
                 + gp.Pad(gt_key, None, 0)
                 + gp.Pad(mask_key, None, 0)
                 + gp.RandomLocation(
-                    ensure_nonempty=sample_points_key
-                    if points_source is not None
-                    else None,
-                    ensure_centered=sample_points_key
-                    if points_source is not None
-                    else None,
+                    ensure_nonempty=(
+                        sample_points_key if points_source is not None else None
+                    ),
+                    ensure_centered=(
+                        sample_points_key if points_source is not None else None
+                    ),
                 )
             )
 
@@ -323,9 +323,11 @@ def next(self):
             NumpyArray.from_gp_array(batch[self._gt_key]),
             NumpyArray.from_gp_array(batch[self._target_key]),
             NumpyArray.from_gp_array(batch[self._weight_key]),
-            NumpyArray.from_gp_array(batch[self._mask_key])
-            if self._mask_key is not None
-            else None,
+            (
+                NumpyArray.from_gp_array(batch[self._mask_key])
+                if self._mask_key is not None
+                else None
+            ),
         )
 
     def __enter__(self):
diff --git a/dacapo/plot.py b/dacapo/plot.py
index c1e02ec95..fcd1c6ee2 100644
--- a/dacapo/plot.py
+++ b/dacapo/plot.py
@@ -74,9 +74,11 @@ def get_runs_info(
             run_config.architecture_config.name,
             run_config.trainer_config.name,
             run_config.datasplit_config.name,
-            stats_store.retrieve_training_stats(run_config_name, subsample=True)
-            if plot_loss
-            else None,
+            (
+                stats_store.retrieve_training_stats(run_config_name, subsample=True)
+                if plot_loss
+                else None
+            ),
             validation_scores,
             validation_score_name,
             plot_loss,
diff --git a/docs/source/conf.py b/docs/source/conf.py
index cd5823612..7df2f563b 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -12,14 +12,15 @@
 #
 import os
 import sys
-sys.path.insert(0, os.path.abspath('../..'))
+
+sys.path.insert(0, os.path.abspath("../.."))
 
 
 # -- Project information -----------------------------------------------------
 
-project = 'DaCapo'
-copyright = '2022, William Patton, David Ackerman, Jan Funke'
-author = 'William Patton, David Ackerman, Jan Funke'
+project = "DaCapo" +copyright = "2022, William Patton, David Ackerman, Jan Funke" +author = "William Patton, David Ackerman, Jan Funke" # -- General configuration --------------------------------------------------- @@ -27,15 +28,15 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.napoleon', 'sphinx_autodoc_typehints'] +extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon", "sphinx_autodoc_typehints"] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- @@ -43,12 +44,12 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_material' +html_theme = "sphinx_material" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] html_css_files = [ - 'css/custom.css', -] \ No newline at end of file + "css/custom.css", +]