diff --git a/pysaliency/precomputed_models.py b/pysaliency/precomputed_models.py index 1c3ba08..a24a460 100644 --- a/pysaliency/precomputed_models.py +++ b/pysaliency/precomputed_models.py @@ -101,14 +101,18 @@ def export_model_to_hdf5(model, stimuli, filename, compression=9, overwrite=True mode = 'a' with h5py.File(filename, mode=mode) as f: - for k, s in enumerate(tqdm(stimuli)): - if not overwrite and names[k] in f: - logging.debug(f"Skipping already existing entry {names[k]}") - continue + if overwrite: + indices = range(len(stimuli)) + else: + indices = [i for i in range(len(stimuli)) if names[i] not in f] + logging.debug(f"Skipping {len(stimuli) - len(indices)} already existing entries") + for k in tqdm(indices): + stimulus = stimuli[k] + if isinstance(model, SaliencyMapModel): - smap = model.saliency_map(s) + smap = model.saliency_map(stimulus) elif isinstance(model, Model): - smap = model.log_density(s) + smap = model.log_density(stimulus) else: raise TypeError(type(model)) f.create_dataset(names[k], data=smap, compression=compression)