diff --git a/tests/operations/test_train.py b/tests/operations/test_train.py index ebb6d4c3..441c4289 100644 --- a/tests/operations/test_train.py +++ b/tests/operations/test_train.py @@ -86,7 +86,7 @@ def test_train_non_stored_unet( trainer_config=trainer, datasplit_config=datasplit, repetition=0, - num_iterations=1, + num_iterations=10, ) run = Run(run_config) train_run(run) @@ -118,7 +118,7 @@ def test_train_unet( trainer_config=trainer, datasplit_config=datasplit, repetition=0, - num_iterations=1, + num_iterations=10, ) try: store.store_run_config(run_config) @@ -136,7 +136,7 @@ def test_train_unet( train_run(run) init_weights = weights_store.retrieve_weights(run.name, 0) - final_weights = weights_store.retrieve_weights(run.name, 1) + final_weights = weights_store.retrieve_weights(run.name, 10) for name, weight in init_weights.model.items(): weight_diff = (weight - final_weights.model[name]).any() @@ -174,7 +174,7 @@ def test_train_unet( # trainer_config=trainer, # datasplit_config=upsample_datasplit, # repetition=0, -# num_iterations=1, +# num_iterations=10, # ) # try: # store.store_run_config(run_config) @@ -193,7 +193,7 @@ def test_train_unet( # # weights_store.store_weights(run, run.train_until) # init_weights = weights_store.retrieve_weights(run.name, 0) -# final_weights = weights_store.retrieve_weights(run.name, 1) +# final_weights = weights_store.retrieve_weights(run.name, 10) # for name, weight in init_weights.model.items(): # weight_diff = (weight - final_weights.model[name]).any()