diff --git a/tests/operations/test_train.py b/tests/operations/test_train.py index be0a94d1..e32276ce 100644 --- a/tests/operations/test_train.py +++ b/tests/operations/test_train.py @@ -49,7 +49,7 @@ def unet_architecture(batch_norm, upsample, use_attention, three_d): name=name, input_shape=(2, 132, 132), eval_shape_increase=(8, 32, 32), - fmaps_in=2, + fmaps_in=1, num_fmaps=8, fmaps_out=8, fmap_inc_factor=2,