From 4b703fe9dc9d1679085174f9573ad54418c1a6b9 Mon Sep 17 00:00:00 2001 From: mzouink Date: Tue, 19 Nov 2024 14:38:35 +0000 Subject: [PATCH] :art: Format Python code with psf/black --- dacapo/experiments/tasks/hot_distance_task.py | 1 + .../tasks/hot_distance_task_config.py | 3 +-- .../post_processors/argmax_post_processor.py | 2 +- .../post_processors/threshold_post_processor.py | 16 +++++++++------- .../tasks/predictors/distance_predictor.py | 2 +- .../tasks/predictors/hot_distance_predictor.py | 8 +++++++- dacapo/predict_local.py | 6 +++--- tests/operations/test_architecture.py | 1 - tests/operations/test_mini.py | 8 +++++++- tests/operations/test_train.py | 1 - tests/operations/test_validate.py | 1 - 11 files changed, 30 insertions(+), 19 deletions(-) diff --git a/dacapo/experiments/tasks/hot_distance_task.py b/dacapo/experiments/tasks/hot_distance_task.py index 382eaf8b..3d86da13 100644 --- a/dacapo/experiments/tasks/hot_distance_task.py +++ b/dacapo/experiments/tasks/hot_distance_task.py @@ -6,6 +6,7 @@ import warnings + class HotDistanceTask(Task): """ A class to represent a hot distance task that use binary prediction and distance prediction. diff --git a/dacapo/experiments/tasks/hot_distance_task_config.py b/dacapo/experiments/tasks/hot_distance_task_config.py index d140e38e..7e0cc37a 100644 --- a/dacapo/experiments/tasks/hot_distance_task_config.py +++ b/dacapo/experiments/tasks/hot_distance_task_config.py @@ -57,7 +57,6 @@ class HotDistanceTaskConfig(TaskConfig): }, ) - kernel_size: int | None = attr.ib( default=None, - ) \ No newline at end of file + ) diff --git a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py index a88b8926..34cb0245 100644 --- a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py @@ -133,7 +133,7 @@ def process( overwrite=True, ) - read_roi = Roi((0,)*block_size.dims, block_size) + read_roi = Roi((0,) * block_size.dims, block_size) input_array = open_ds( f"{self.prediction_array_identifier.container.path}/{self.prediction_array_identifier.dataset}" ) diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py index 778064fc..24ecead7 100644 --- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py @@ -111,13 +111,15 @@ def process( if self.prediction_array._source_data.chunks is not None: block_size = self.prediction_array._source_data.chunks - write_size = Coordinate([ - b * v - for b, v in zip( - block_size[-self.prediction_array.dims :], - self.prediction_array.voxel_size, - ) - ]) + write_size = Coordinate( + [ + b * v + for b, v in zip( + block_size[-self.prediction_array.dims :], + self.prediction_array.voxel_size, + ) + ] + ) output_array = create_from_identifier( output_array_identifier, self.prediction_array.axis_names, diff --git a/dacapo/experiments/tasks/predictors/distance_predictor.py b/dacapo/experiments/tasks/predictors/distance_predictor.py index 172db065..741e14db 100644 --- a/dacapo/experiments/tasks/predictors/distance_predictor.py +++ b/dacapo/experiments/tasks/predictors/distance_predictor.py @@ -314,7 +314,7 @@ def process( channel_dim = True else: raise ValueError("Cannot handle multiple channel dims") - + if not channel_dim: labels = labels[np.newaxis] diff --git a/dacapo/experiments/tasks/predictors/hot_distance_predictor.py b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py index 4a18a315..f2ec4f87 100644 --- a/dacapo/experiments/tasks/predictors/hot_distance_predictor.py +++ b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py @@ -49,7 +49,13 @@ class HotDistancePredictor(Predictor): This is a subclass of Predictor. """ - def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool, kernel_size: int): + def __init__( + self, + channels: List[str], + scale_factor: float, + mask_distances: bool, + kernel_size: int, + ): """ Initializes the HotDistancePredictor. diff --git a/dacapo/predict_local.py b/dacapo/predict_local.py index 88aaa5cf..f1760ff9 100644 --- a/dacapo/predict_local.py +++ b/dacapo/predict_local.py @@ -75,8 +75,8 @@ def predict( model_device = str(next(model.parameters()).device).split(":")[0] - assert ( - model_device == str(device) + assert model_device == str( + device ), f"Model is not on the right device, Model: {model_device}, Compute device: {device}" def predict_fn(block): @@ -122,7 +122,7 @@ def predict_fn(block): task = daisy.Task( f"predict_{out_container}_{out_dataset}", total_roi=input_roi, - read_roi=Roi((0,)*input_size.dims, input_size), + read_roi=Roi((0,) * input_size.dims, input_size), write_roi=Roi(context, output_size), process_function=predict_fn, check_function=None, diff --git a/tests/operations/test_architecture.py b/tests/operations/test_architecture.py index 2be724d0..e3e569a4 100644 --- a/tests/operations/test_architecture.py +++ b/tests/operations/test_architecture.py @@ -68,7 +68,6 @@ def test_conv_dims( raise ValueError(f"Conv2d found in 3d unet {name}") - @pytest.mark.parametrize( "run_config", [ diff --git a/tests/operations/test_mini.py b/tests/operations/test_mini.py index abc49239..f5070553 100644 --- a/tests/operations/test_mini.py +++ b/tests/operations/test_mini.py @@ -59,7 +59,13 @@ def test_mini( ) task_config = build_test_task_config(task, data_dims, architecture_dims) architecture_config = build_test_architecture_config( - data_dims, architecture_dims, channels, batch_norm, upsample, use_attention, padding + data_dims, + architecture_dims, + channels, + batch_norm, + upsample, + use_attention, + padding, ) run_config = RunConfig( diff --git a/tests/operations/test_train.py b/tests/operations/test_train.py index ae8ad176..ad45b848 100644 --- a/tests/operations/test_train.py +++ b/tests/operations/test_train.py @@ -58,4 +58,3 @@ def test_large( training_stats = stats_store.retrieve_training_stats(run_config.name) assert training_stats.trained_until() == run_config.num_iterations - diff --git a/tests/operations/test_validate.py b/tests/operations/test_validate.py index 4df49a60..8a4d8cf2 100644 --- a/tests/operations/test_validate.py +++ b/tests/operations/test_validate.py @@ -38,4 +38,3 @@ def test_large( # test validating weights that don't exist with pytest.raises(FileNotFoundError): validate(run_config.name, 2) -