diff --git a/returnn/datasets/basic.py b/returnn/datasets/basic.py index c35d2327dd..60fe6fa7d3 100644 --- a/returnn/datasets/basic.py +++ b/returnn/datasets/basic.py @@ -447,12 +447,16 @@ def get_seq_order_for_epoch(self, epoch, num_seqs, get_seq_len=None): This is mostly a static method, except that is depends on the configured type of ordering, such as 'default' (= as-is), 'sorted' or 'random'. 'sorted' also uses the sequence length. - :param int epoch: for 'random', this determines the random seed + :param int|None epoch: for 'random', this determines the random seed :param int num_seqs: :param ((int) -> int)|None get_seq_len: function (originalSeqIdx: int) -> int :return: the order for the given epoch. such that seq_idx -> underlying idx :rtype: typing.Sequence[int] """ + if epoch is None: + # This might be called in the beginning. Skip this and wait until we init the real relevant epoch. + # We are not expected to have prepared any real epoch here. + return [] partition_epoch = self.partition_epoch or 1 repeat_epoch = self.repeat_epoch or 1 assert num_seqs > 0 diff --git a/returnn/datasets/cached2.py b/returnn/datasets/cached2.py index 34e596281b..6c25b64374 100644 --- a/returnn/datasets/cached2.py +++ b/returnn/datasets/cached2.py @@ -52,8 +52,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): Call this when you reset the seq list. """ super(CachedDataset2, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) - if not epoch: - epoch = 1 self.expected_load_seq_start = 0 self.reached_final_seq = False self.added_data = [] diff --git a/returnn/datasets/generating.py b/returnn/datasets/generating.py index e6bd63f2b0..1029738f38 100644 --- a/returnn/datasets/generating.py +++ b/returnn/datasets/generating.py @@ -1664,10 +1664,10 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): """ assert seq_list is None and seq_order is None super(TimitDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) - self._num_seqs = len(self._seq_tags) self._seq_order = self.get_seq_order_for_epoch( - epoch=epoch, num_seqs=self._num_seqs, get_seq_len=lambda i: len(self._audio_data[self._seq_tags[i]][0]) + epoch=epoch, num_seqs=len(self._seq_tags), get_seq_len=lambda i: len(self._audio_data[self._seq_tags[i]][0]) ) + self._num_seqs = len(self._seq_order) self._random.seed(self._get_random_seed_for_epoch(epoch=epoch)) return True @@ -2081,8 +2081,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): import returnn.util.basic super(LibriSpeechCorpus, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) - if not epoch: - epoch = 1 random_seed = self._get_random_seed_for_epoch(epoch=epoch) self._audio_random.seed(random_seed) if self.targets: @@ -2107,7 +2105,7 @@ def get_seq_len(i): num_seqs = len(self._reference_seq_order) self._seq_order = self.get_seq_order_for_epoch(epoch=epoch, num_seqs=num_seqs, get_seq_len=get_seq_len) self._num_seqs = len(self._seq_order) - if self.epoch_wise_filter: + if self.epoch_wise_filter and epoch is not None: # Note: A more generic variant of this code is :class:`MetaDataset.EpochWiseFilter`. from .meta import EpochWiseFilter @@ -2356,10 +2354,8 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): :rtype: bool """ super(Enwik8Corpus, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) - if not epoch: - epoch = 1 epoch_part = None - if self.partition_epoch: + if self.partition_epoch and epoch is not None: epoch_part = (epoch - 1) % self.partition_epoch epoch = ((epoch - 1) // self.partition_epoch) + 1 self._random.seed(self._get_random_seed_for_epoch(epoch=epoch)) @@ -2380,7 +2376,7 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): seq_index = seq_index.transpose() seq_index = seq_index.flatten() self._seq_order = seq_index - if self.partition_epoch: + if self.partition_epoch and epoch is not None: assert self._num_seqs >= self.partition_epoch partition_epoch_num_seqs = [self._num_seqs // self.partition_epoch] * self.partition_epoch i = 0 diff --git a/returnn/datasets/hdf.py b/returnn/datasets/hdf.py index f8befcfcd2..551ce0143f 100644 --- a/returnn/datasets/hdf.py +++ b/returnn/datasets/hdf.py @@ -726,7 +726,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): elif seq_list is not None: self.seq_order = [self.seq_name_to_idx[s] for s in seq_list] else: - epoch = epoch or 1 self.seq_order = self.get_seq_order_for_epoch(epoch, len(self.all_seq_names), self._get_seq_length) def supports_seq_order_sorting(self) -> bool: diff --git a/returnn/datasets/lm.py b/returnn/datasets/lm.py index 5038c843d7..ba9a611e1b 100644 --- a/returnn/datasets/lm.py +++ b/returnn/datasets/lm.py @@ -282,8 +282,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): ) self.error_on_invalid_seq = True super(LmDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) - if not epoch: - epoch = 1 if seq_order is not None: self.seq_order = seq_order @@ -298,7 +296,7 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): self.num_skipped = 0 self.num_unknown = 0 if self.seq_gen: - self.seq_gen.random_seed(epoch) + self.seq_gen.random_seed(self._get_random_seed_for_epoch(epoch)) return True def supports_seq_order_sorting(self) -> bool: @@ -1458,8 +1456,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): :returns whether the order changed (True is always safe to return) """ super(TranslationDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) - if not epoch: - epoch = 1 if seq_list is None and self.seq_list: seq_list = self.seq_list diff --git a/returnn/datasets/meta.py b/returnn/datasets/meta.py index b5408b5184..5da60c8535 100644 --- a/returnn/datasets/meta.py +++ b/returnn/datasets/meta.py @@ -93,7 +93,8 @@ def filter(self, epoch, seq_order, get_seq_len): :param ((int)->int) get_seq_len: seq idx -> len :return: new seq_order """ - epoch = epoch or 1 + if epoch is None: + return seq_order old_num_seqs = len(seq_order) any_filter = False for (ep_start, ep_end), value in sorted(self.epochs_opts.items()): @@ -388,7 +389,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): super(MetaDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) if epoch is None: # This is called via initialize() with epoch=None, just to init some other things. - self.epoch = None # make sure we properly reinit # We are not expected to have prepared any real epoch here. self._num_seqs = 0 return True