diff --git a/returnn/datasets/basic.py b/returnn/datasets/basic.py index 38676938c3..5ca536e0fe 100644 --- a/returnn/datasets/basic.py +++ b/returnn/datasets/basic.py @@ -956,7 +956,9 @@ def have_seqs(self) -> bool: return total_num_seqs > 0 except NotImplementedError: pass - return self.is_less_than_num_seqs(0) + if self.epoch is not None: + return self.is_less_than_num_seqs(0) + raise NotImplementedError(f"{self} have_seqs() is not implemented (and neither get_total_num_seqs())") def len_info(self): """ diff --git a/returnn/datasets/cached2.py b/returnn/datasets/cached2.py index 6c25b64374..929db020dd 100644 --- a/returnn/datasets/cached2.py +++ b/returnn/datasets/cached2.py @@ -98,6 +98,8 @@ def num_seqs(self): """ if self._num_seqs is not None: return self._num_seqs + if self.epoch is None: + return 0 raise NotImplementedError def _load_seqs(self, start, end): @@ -128,6 +130,8 @@ def is_less_than_num_seqs(self, n): """ if n < self.expected_load_seq_start: return True + if self.epoch is None: + return False # noinspection PyBroadException try: return super(CachedDataset2, self).is_less_than_num_seqs(n) diff --git a/returnn/datasets/lm.py b/returnn/datasets/lm.py index ba9a611e1b..7bd8d5853a 100644 --- a/returnn/datasets/lm.py +++ b/returnn/datasets/lm.py @@ -303,6 +303,10 @@ def supports_seq_order_sorting(self) -> bool: """supports sorting""" return True + def get_total_num_seqs(self) -> int: + """total num seqs""" + return len(self.orths) + def _reduce_log_skipped_seqs(self): if isinstance(self.log_skipped_seqs, bool): return