Skip to content

Commit

Permalink
🎨 Format Python code with psf/black
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink authored Nov 14, 2024
1 parent 58a16f2 commit 0979152
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 8 deletions.
5 changes: 3 additions & 2 deletions dacapo/predict_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def predict(

model_device = str(next(model.parameters()).device).split(":")[0]

assert model_device == str(device), f"Model is not on the right device, Model: {model_device}, Compute device: {device}"
assert model_device == str(
device
), f"Model is not on the right device, Model: {model_device}, Compute device: {device}"

def predict_fn(block):
raw_input = raw_array.to_ndarray(block.read_roi)
Expand All @@ -89,7 +91,6 @@ def predict_fn(block):
raw_input = np.expand_dims(raw_input, 0)
axis_names = ["c^"] + axis_names


with torch.no_grad():
model.eval()
predictions = (
Expand Down
8 changes: 7 additions & 1 deletion tests/fixtures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
from .losses import dummy_loss
from .post_processors import argmax, threshold
from .predictors import distance_predictor, onehot_predictor
from .runs import dummy_run, distance_run, onehot_run, unet_2d_distance_run, unet_3d_distance_run
from .runs import (
dummy_run,
distance_run,
onehot_run,
unet_2d_distance_run,
unet_3d_distance_run,
)
from .tasks import dummy_task, distance_task, onehot_task, six_onehot_task
from .trainers import dummy_trainer, gunpowder_trainer
4 changes: 2 additions & 2 deletions tests/fixtures/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def unet_architecture_builder(batch_norm, upsample, use_attention, three_d):
num_fmaps=8,
fmaps_out=8,
fmap_inc_factor=2,
downsample_factors=[(4, 4), ( 4, 4)],
kernel_size_down=[[( 3, 3)] * 2] * 3,
downsample_factors=[(4, 4), (4, 4)],
kernel_size_down=[[(3, 3)] * 2] * 3,
kernel_size_up=[[(3, 3)] * 2] * 2,
constant_upsample=True,
padding="valid",
Expand Down
1 change: 0 additions & 1 deletion tests/fixtures/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def unet_2d_distance_run(
)



@pytest.fixture()
def unet_3d_distance_run(
six_class_datasplit,
Expand Down
1 change: 0 additions & 1 deletion tests/operations/test_architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def test_2d_conv_unet(
raise ValueError(f"Conv2d found in 3d unet {name}")



@pytest.mark.parametrize(
"run_config",
[
Expand Down
1 change: 0 additions & 1 deletion tests/operations/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ def test_train_unet(
batch_norm, upsample, use_attention, three_d
)


run_config = RunConfig(
name=f"{architecture_config.name}_run",
task_config=task,
Expand Down

0 comments on commit 0979152

Please sign in to comment.