diff --git a/mlspm/data_generation.py b/mlspm/data_generation.py index be17777..55b2af2 100644 --- a/mlspm/data_generation.py +++ b/mlspm/data_generation.py @@ -170,7 +170,7 @@ def __init__(self, samples: list[TarSampleList], base_path: PathLike = "./", n_p def __len__(self) -> int: """Total number of samples (including rotations)""" - return sum([len(s["rots"]) for s in self.samples]) + return sum([sum([len(rots) for rots in sample_list["rots"]]) for sample_list in self.samples]) def _launch_procs(self): queue_size = 2 * self.n_proc