From 1abb3fdff8af15608e07db84581673c8909124c0 Mon Sep 17 00:00:00 2001 From: mzouink Date: Thu, 14 Nov 2024 14:49:56 +0000 Subject: [PATCH] :art: Format Python code with psf/black --- tests/fixtures/architectures.py | 21 ++++++++++----------- tests/fixtures/tasks.py | 3 ++- tests/operations/test_validate.py | 4 +++- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/tests/fixtures/architectures.py b/tests/fixtures/architectures.py index a75a7764..71b2251d 100644 --- a/tests/fixtures/architectures.py +++ b/tests/fixtures/architectures.py @@ -31,20 +31,19 @@ def unet_architecture(): ) - @pytest.fixture() def unet_3d_architecture(): yield CNNectomeUNetConfig( - name="tmp_unet_3d_architecture", - input_shape=(188, 188, 188), - eval_shape_increase=(72, 72, 72), - fmaps_in=1, - num_fmaps=6, - fmaps_out=6, - fmap_inc_factor=2, - downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)], - constant_upsample=True, - ) + name="tmp_unet_3d_architecture", + input_shape=(188, 188, 188), + eval_shape_increase=(72, 72, 72), + fmaps_in=1, + num_fmaps=6, + fmaps_out=6, + fmap_inc_factor=2, + downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)], + constant_upsample=True, + ) def unet_architecture_builder(batch_norm, upsample, use_attention, three_d): diff --git a/tests/fixtures/tasks.py b/tests/fixtures/tasks.py index 3f91106b..4230fd9b 100644 --- a/tests/fixtures/tasks.py +++ b/tests/fixtures/tasks.py @@ -35,9 +35,10 @@ def onehot_task(): classes=["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"], ) + @pytest.fixture() def six_onehot_task(): yield OneHotTaskConfig( name="one_hot_task", classes=["a", "b", "c", "d", "e", "f"], - ) \ No newline at end of file + ) diff --git a/tests/operations/test_validate.py b/tests/operations/test_validate.py index abec46ea..1ae80cd0 100644 --- a/tests/operations/test_validate.py +++ b/tests/operations/test_validate.py @@ -104,7 +104,9 @@ def test_validate_run( @pytest.mark.parametrize("datasplit", [lf("six_class_datasplit")]) @pytest.mark.parametrize("task", [lf("distance_task"), lf("six_onehot_task")]) @pytest.mark.parametrize("trainer", [lf("gunpowder_trainer")]) -@pytest.mark.parametrize("architecture", [lf("unet_architecture"), lf("unet_3d_architecture")]) +@pytest.mark.parametrize( + "architecture", [lf("unet_architecture"), lf("unet_3d_architecture")] +) def test_validate_unet(datasplit, task, trainer, architecture): store = create_config_store() weights_store = create_weights_store()