From a4908faa354ee6a320384601c592f3dc7cd18df3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20K=C3=BCmmerer?= Date: Sun, 22 Sep 2024 13:13:54 +0200 Subject: [PATCH] Better progress indication when exporting models to HDF5 without overwriting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Matthias Kümmerer --- pysaliency/precomputed_models.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/pysaliency/precomputed_models.py b/pysaliency/precomputed_models.py index 1c3ba08..a24a460 100644 --- a/pysaliency/precomputed_models.py +++ b/pysaliency/precomputed_models.py @@ -101,14 +101,18 @@ def export_model_to_hdf5(model, stimuli, filename, compression=9, overwrite=True mode = 'a' with h5py.File(filename, mode=mode) as f: - for k, s in enumerate(tqdm(stimuli)): - if not overwrite and names[k] in f: - logging.debug(f"Skipping already existing entry {names[k]}") - continue + if overwrite: + indices = range(len(stimuli)) + else: + indices = [i for i in range(len(stimuli)) if names[i] not in f] + logging.debug(f"Skipping {len(stimuli) - len(indices)} already existing entries") + for k in tqdm(indices): + stimulus = stimuli[k] + if isinstance(model, SaliencyMapModel): - smap = model.saliency_map(s) + smap = model.saliency_map(stimulus) elif isinstance(model, Model): - smap = model.log_density(s) + smap = model.log_density(stimulus) else: raise TypeError(type(model)) f.create_dataset(names[k], data=smap, compression=compression)