Skip to content

Commit

Permalink
🎨 Format Python code with psf/black
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink authored Nov 19, 2024
1 parent aa3b72c commit 4b703fe
Show file tree
Hide file tree
Showing 11 changed files with 30 additions and 19 deletions.
1 change: 1 addition & 0 deletions dacapo/experiments/tasks/hot_distance_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import warnings


class HotDistanceTask(Task):
"""
A class to represent a hot distance task that use binary prediction and distance prediction.
Expand Down
3 changes: 1 addition & 2 deletions dacapo/experiments/tasks/hot_distance_task_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ class HotDistanceTaskConfig(TaskConfig):
},
)


kernel_size: int | None = attr.ib(
default=None,
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def process(
overwrite=True,
)

read_roi = Roi((0,)*block_size.dims, block_size)
read_roi = Roi((0,) * block_size.dims, block_size)
input_array = open_ds(
f"{self.prediction_array_identifier.container.path}/{self.prediction_array_identifier.dataset}"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,15 @@ def process(
if self.prediction_array._source_data.chunks is not None:
block_size = self.prediction_array._source_data.chunks

write_size = Coordinate([
b * v
for b, v in zip(
block_size[-self.prediction_array.dims :],
self.prediction_array.voxel_size,
)
])
write_size = Coordinate(
[
b * v
for b, v in zip(
block_size[-self.prediction_array.dims :],
self.prediction_array.voxel_size,
)
]
)
output_array = create_from_identifier(
output_array_identifier,
self.prediction_array.axis_names,
Expand Down
2 changes: 1 addition & 1 deletion dacapo/experiments/tasks/predictors/distance_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def process(
channel_dim = True
else:
raise ValueError("Cannot handle multiple channel dims")

if not channel_dim:
labels = labels[np.newaxis]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,13 @@ class HotDistancePredictor(Predictor):
This is a subclass of Predictor.
"""

def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool, kernel_size: int):
def __init__(
self,
channels: List[str],
scale_factor: float,
mask_distances: bool,
kernel_size: int,
):
"""
Initializes the HotDistancePredictor.
Expand Down
6 changes: 3 additions & 3 deletions dacapo/predict_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def predict(

model_device = str(next(model.parameters()).device).split(":")[0]

assert (
model_device == str(device)
assert model_device == str(
device
), f"Model is not on the right device, Model: {model_device}, Compute device: {device}"

def predict_fn(block):
Expand Down Expand Up @@ -122,7 +122,7 @@ def predict_fn(block):
task = daisy.Task(
f"predict_{out_container}_{out_dataset}",
total_roi=input_roi,
read_roi=Roi((0,)*input_size.dims, input_size),
read_roi=Roi((0,) * input_size.dims, input_size),
write_roi=Roi(context, output_size),
process_function=predict_fn,
check_function=None,
Expand Down
1 change: 0 additions & 1 deletion tests/operations/test_architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def test_conv_dims(
raise ValueError(f"Conv2d found in 3d unet {name}")



@pytest.mark.parametrize(
"run_config",
[
Expand Down
8 changes: 7 additions & 1 deletion tests/operations/test_mini.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,13 @@ def test_mini(
)
task_config = build_test_task_config(task, data_dims, architecture_dims)
architecture_config = build_test_architecture_config(
data_dims, architecture_dims, channels, batch_norm, upsample, use_attention, padding
data_dims,
architecture_dims,
channels,
batch_norm,
upsample,
use_attention,
padding,
)

run_config = RunConfig(
Expand Down
1 change: 0 additions & 1 deletion tests/operations/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,3 @@ def test_large(
training_stats = stats_store.retrieve_training_stats(run_config.name)

assert training_stats.trained_until() == run_config.num_iterations

1 change: 0 additions & 1 deletion tests/operations/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,3 @@ def test_large(
# test validating weights that don't exist
with pytest.raises(FileNotFoundError):
validate(run_config.name, 2)

0 comments on commit 4b703fe

Please sign in to comment.