Skip to content

Commit

Permalink
more iterations:
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink committed Nov 14, 2024
1 parent 4a18c42 commit 58a16f2
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions tests/operations/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_train_non_stored_unet(
trainer_config=trainer,
datasplit_config=datasplit,
repetition=0,
num_iterations=1,
num_iterations=10,
)
run = Run(run_config)
train_run(run)
Expand Down Expand Up @@ -118,7 +118,7 @@ def test_train_unet(
trainer_config=trainer,
datasplit_config=datasplit,
repetition=0,
num_iterations=1,
num_iterations=10,
)
try:
store.store_run_config(run_config)
Expand All @@ -136,7 +136,7 @@ def test_train_unet(
train_run(run)

init_weights = weights_store.retrieve_weights(run.name, 0)
final_weights = weights_store.retrieve_weights(run.name, 1)
final_weights = weights_store.retrieve_weights(run.name, 10)

for name, weight in init_weights.model.items():
weight_diff = (weight - final_weights.model[name]).any()
Expand Down Expand Up @@ -174,7 +174,7 @@ def test_train_unet(
# trainer_config=trainer,
# datasplit_config=upsample_datasplit,
# repetition=0,
# num_iterations=1,
# num_iterations=10,
# )
# try:
# store.store_run_config(run_config)
Expand All @@ -193,7 +193,7 @@ def test_train_unet(
# # weights_store.store_weights(run, run.train_until)

# init_weights = weights_store.retrieve_weights(run.name, 0)
# final_weights = weights_store.retrieve_weights(run.name, 1)
# final_weights = weights_store.retrieve_weights(run.name, 10)

# for name, weight in init_weights.model.items():
# weight_diff = (weight - final_weights.model[name]).any()
Expand Down

0 comments on commit 58a16f2

Please sign in to comment.