Skip to content

Commit

Permalink
Datasets, handle epoch None (init) better, keep it
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Sep 2, 2023
1 parent e215b19 commit 6bfeb97
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 20 deletions.
6 changes: 5 additions & 1 deletion returnn/datasets/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,12 +447,16 @@ def get_seq_order_for_epoch(self, epoch, num_seqs, get_seq_len=None):
This is mostly a static method, except that is depends on the configured type of ordering,
such as 'default' (= as-is), 'sorted' or 'random'. 'sorted' also uses the sequence length.
:param int epoch: for 'random', this determines the random seed
:param int|None epoch: for 'random', this determines the random seed
:param int num_seqs:
:param ((int) -> int)|None get_seq_len: function (originalSeqIdx: int) -> int
:return: the order for the given epoch. such that seq_idx -> underlying idx
:rtype: typing.Sequence[int]
"""
if epoch is None:
# This might be called in the beginning. Skip this and wait until we init the real relevant epoch.
# We are not expected to have prepared any real epoch here.
return []
partition_epoch = self.partition_epoch or 1
repeat_epoch = self.repeat_epoch or 1
assert num_seqs > 0
Expand Down
2 changes: 0 additions & 2 deletions returnn/datasets/cached2.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
Call this when you reset the seq list.
"""
super(CachedDataset2, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
if not epoch:
epoch = 1
self.expected_load_seq_start = 0
self.reached_final_seq = False
self.added_data = []
Expand Down
14 changes: 5 additions & 9 deletions returnn/datasets/generating.py
Original file line number Diff line number Diff line change
Expand Up @@ -1664,10 +1664,10 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
"""
assert seq_list is None and seq_order is None
super(TimitDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
self._num_seqs = len(self._seq_tags)
self._seq_order = self.get_seq_order_for_epoch(
epoch=epoch, num_seqs=self._num_seqs, get_seq_len=lambda i: len(self._audio_data[self._seq_tags[i]][0])
epoch=epoch, num_seqs=len(self._seq_tags), get_seq_len=lambda i: len(self._audio_data[self._seq_tags[i]][0])
)
self._num_seqs = len(self._seq_order)
self._random.seed(self._get_random_seed_for_epoch(epoch=epoch))
return True

Expand Down Expand Up @@ -2081,8 +2081,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
import returnn.util.basic

super(LibriSpeechCorpus, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
if not epoch:
epoch = 1
random_seed = self._get_random_seed_for_epoch(epoch=epoch)
self._audio_random.seed(random_seed)
if self.targets:
Expand All @@ -2107,7 +2105,7 @@ def get_seq_len(i):
num_seqs = len(self._reference_seq_order)
self._seq_order = self.get_seq_order_for_epoch(epoch=epoch, num_seqs=num_seqs, get_seq_len=get_seq_len)
self._num_seqs = len(self._seq_order)
if self.epoch_wise_filter:
if self.epoch_wise_filter and epoch is not None:
# Note: A more generic variant of this code is :class:`MetaDataset.EpochWiseFilter`.
from .meta import EpochWiseFilter

Expand Down Expand Up @@ -2356,10 +2354,8 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
:rtype: bool
"""
super(Enwik8Corpus, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
if not epoch:
epoch = 1
epoch_part = None
if self.partition_epoch:
if self.partition_epoch and epoch is not None:
epoch_part = (epoch - 1) % self.partition_epoch
epoch = ((epoch - 1) // self.partition_epoch) + 1
self._random.seed(self._get_random_seed_for_epoch(epoch=epoch))
Expand All @@ -2380,7 +2376,7 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
seq_index = seq_index.transpose()
seq_index = seq_index.flatten()
self._seq_order = seq_index
if self.partition_epoch:
if self.partition_epoch and epoch is not None:
assert self._num_seqs >= self.partition_epoch
partition_epoch_num_seqs = [self._num_seqs // self.partition_epoch] * self.partition_epoch
i = 0
Expand Down
1 change: 0 additions & 1 deletion returnn/datasets/hdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
elif seq_list is not None:
self.seq_order = [self.seq_name_to_idx[s] for s in seq_list]
else:
epoch = epoch or 1
self.seq_order = self.get_seq_order_for_epoch(epoch, len(self.all_seq_names), self._get_seq_length)

def supports_seq_order_sorting(self) -> bool:
Expand Down
6 changes: 1 addition & 5 deletions returnn/datasets/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
)
self.error_on_invalid_seq = True
super(LmDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
if not epoch:
epoch = 1

if seq_order is not None:
self.seq_order = seq_order
Expand All @@ -298,7 +296,7 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
self.num_skipped = 0
self.num_unknown = 0
if self.seq_gen:
self.seq_gen.random_seed(epoch)
self.seq_gen.random_seed(self._get_random_seed_for_epoch(epoch))
return True

def supports_seq_order_sorting(self) -> bool:
Expand Down Expand Up @@ -1458,8 +1456,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
:returns whether the order changed (True is always safe to return)
"""
super(TranslationDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
if not epoch:
epoch = 1

if seq_list is None and self.seq_list:
seq_list = self.seq_list
Expand Down
4 changes: 2 additions & 2 deletions returnn/datasets/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ def filter(self, epoch, seq_order, get_seq_len):
:param ((int)->int) get_seq_len: seq idx -> len
:return: new seq_order
"""
epoch = epoch or 1
if epoch is None:
return seq_order
old_num_seqs = len(seq_order)
any_filter = False
for (ep_start, ep_end), value in sorted(self.epochs_opts.items()):
Expand Down Expand Up @@ -388,7 +389,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
super(MetaDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
if epoch is None:
# This is called via initialize() with epoch=None, just to init some other things.
self.epoch = None # make sure we properly reinit
# We are not expected to have prepared any real epoch here.
self._num_seqs = 0
return True
Expand Down

0 comments on commit 6bfeb97

Please sign in to comment.