From 762b9002bc1b5b40e9a86733917969395e244cc3 Mon Sep 17 00:00:00 2001 From: Dingel321 Date: Wed, 6 Dec 2023 16:54:27 +0100 Subject: [PATCH] cleaned train_npe --- src/cryo_sbi/inference/train_npe_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/cryo_sbi/inference/train_npe_model.py b/src/cryo_sbi/inference/train_npe_model.py index 732a08b..612f990 100644 --- a/src/cryo_sbi/inference/train_npe_model.py +++ b/src/cryo_sbi/inference/train_npe_model.py @@ -17,6 +17,7 @@ from cryo_sbi.wpa_simulator.cryo_em_simulator import cryo_em_simulator from cryo_sbi.wpa_simulator.validate_image_config import check_image_params from cryo_sbi.inference.validate_train_config import check_train_params +import cryo_sbi.utils.image_utils as img_utils def load_model( @@ -121,7 +122,7 @@ def npe_train_no_saving( ) step = GDStep(optimizer, clip=train_config["CLIP_GRADIENT"]) mean_loss = [] - + print("Training neural netowrk:") estimator.train() with tqdm(range(epochs), unit="epoch") as tq: