diff --git a/dacapo/predict_local.py b/dacapo/predict_local.py index d829dffd..a4f21ab1 100644 --- a/dacapo/predict_local.py +++ b/dacapo/predict_local.py @@ -73,7 +73,9 @@ def predict( model_device = str(next(model.parameters()).device).split(":")[0] - assert model_device == str(device), f"Model is not on the right device, Model: {model_device}, Compute device: {device}" + assert model_device == str( + device + ), f"Model is not on the right device, Model: {model_device}, Compute device: {device}" def predict_fn(block): raw_input = raw_array.to_ndarray(block.read_roi) @@ -89,7 +91,6 @@ def predict_fn(block): raw_input = np.expand_dims(raw_input, 0) axis_names = ["c^"] + axis_names - with torch.no_grad(): model.eval() predictions = ( diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index d4186d29..1c190847 100644 --- a/tests/fixtures/__init__.py +++ b/tests/fixtures/__init__.py @@ -16,6 +16,12 @@ from .losses import dummy_loss from .post_processors import argmax, threshold from .predictors import distance_predictor, onehot_predictor -from .runs import dummy_run, distance_run, onehot_run, unet_2d_distance_run, unet_3d_distance_run +from .runs import ( + dummy_run, + distance_run, + onehot_run, + unet_2d_distance_run, + unet_3d_distance_run, +) from .tasks import dummy_task, distance_task, onehot_task, six_onehot_task from .trainers import dummy_trainer, gunpowder_trainer diff --git a/tests/fixtures/architectures.py b/tests/fixtures/architectures.py index 558dea56..4baca2da 100644 --- a/tests/fixtures/architectures.py +++ b/tests/fixtures/architectures.py @@ -76,8 +76,8 @@ def unet_architecture_builder(batch_norm, upsample, use_attention, three_d): num_fmaps=8, fmaps_out=8, fmap_inc_factor=2, - downsample_factors=[(4, 4), ( 4, 4)], - kernel_size_down=[[( 3, 3)] * 2] * 3, + downsample_factors=[(4, 4), (4, 4)], + kernel_size_down=[[(3, 3)] * 2] * 3, kernel_size_up=[[(3, 3)] * 2] * 2, constant_upsample=True, padding="valid", diff --git a/tests/fixtures/runs.py b/tests/fixtures/runs.py index d5e584f9..b508079b 100644 --- a/tests/fixtures/runs.py +++ b/tests/fixtures/runs.py @@ -75,7 +75,6 @@ def unet_2d_distance_run( ) - @pytest.fixture() def unet_3d_distance_run( six_class_datasplit, diff --git a/tests/operations/test_architecture.py b/tests/operations/test_architecture.py index 55ddf765..1969ce33 100644 --- a/tests/operations/test_architecture.py +++ b/tests/operations/test_architecture.py @@ -82,7 +82,6 @@ def test_2d_conv_unet( raise ValueError(f"Conv2d found in 3d unet {name}") - @pytest.mark.parametrize( "run_config", [ diff --git a/tests/operations/test_train.py b/tests/operations/test_train.py index 441c4289..bbb0ac22 100644 --- a/tests/operations/test_train.py +++ b/tests/operations/test_train.py @@ -110,7 +110,6 @@ def test_train_unet( batch_norm, upsample, use_attention, three_d ) - run_config = RunConfig( name=f"{architecture_config.name}_run", task_config=task,