diff --git a/returnn/__main__.py b/returnn/__main__.py index 54f6a5b229..ade63b2e9c 100644 --- a/returnn/__main__.py +++ b/returnn/__main__.py @@ -469,6 +469,7 @@ def execute_main_task(): assert data, "set forward_data" else: data = init_dataset(config.opt_typed_value("forward_data")) + data.init_seq_order(epoch=engine.epoch or 1) forward_callback = config.typed_value("forward_callback") assert forward_callback, "no forward_callback specified" if callable(forward_callback): @@ -482,6 +483,7 @@ def execute_main_task(): if config.has("epoch"): config.set("load_epoch", config.int("epoch", 0)) engine.init_network_from_config(config) + eval_data.init_seq_order(epoch=engine.epoch or 1) output_file = config.value("output_file", "dump-fwd-epoch-%i.hdf" % engine.epoch) forward_batch_size = config.int("forward_batch_size", 0) if not forward_batch_size: diff --git a/returnn/datasets/basic.py b/returnn/datasets/basic.py index c35d2327dd..38676938c3 100644 --- a/returnn/datasets/basic.py +++ b/returnn/datasets/basic.py @@ -198,6 +198,7 @@ def __repr__(self): ) _getnewargs_exclude_attrs = set() # type: typing.Set[str] + _getnewargs_remap = {} # type: typing.Dict[str,str] @staticmethod def _create_from_reduce(cls, kwargs, state) -> Dataset: @@ -223,7 +224,9 @@ def __reduce__(self): for arg in inspect.getargs(cls.__init__.__code__).args[1:]: if arg in self._getnewargs_exclude_attrs: continue - if hasattr(self, "_" + arg): + if arg in self._getnewargs_remap: + kwargs[arg] = getattr(self, self._getnewargs_remap[arg]) + elif hasattr(self, "_" + arg): kwargs[arg] = getattr(self, "_" + arg) else: kwargs[arg] = getattr(self, arg) @@ -447,12 +450,16 @@ def get_seq_order_for_epoch(self, epoch, num_seqs, get_seq_len=None): This is mostly a static method, except that is depends on the configured type of ordering, such as 'default' (= as-is), 'sorted' or 'random'. 'sorted' also uses the sequence length. - :param int epoch: for 'random', this determines the random seed + :param int|None epoch: for 'random', this determines the random seed :param int num_seqs: :param ((int) -> int)|None get_seq_len: function (originalSeqIdx: int) -> int :return: the order for the given epoch. such that seq_idx -> underlying idx :rtype: typing.Sequence[int] """ + if epoch is None: + # This might be called in the beginning. Skip this and wait until we init the real relevant epoch. + # We are not expected to have prepared any real epoch here. + return [] partition_epoch = self.partition_epoch or 1 repeat_epoch = self.repeat_epoch or 1 assert num_seqs > 0 diff --git a/returnn/datasets/cached.py b/returnn/datasets/cached.py index 1e0d51f17d..6d4a380d73 100644 --- a/returnn/datasets/cached.py +++ b/returnn/datasets/cached.py @@ -152,6 +152,8 @@ def batch_set_generator_cache_whole_epoch(self): def _init_alloc_intervals(self): if self.cache_byte_size_limit_at_start == 0: return + if self.epoch is None: + return assert self.num_seqs > 0 assert self.num_inputs > 0 assert self.window > 0 @@ -183,6 +185,8 @@ def _init_start_cache(self): return if not self.nbytes: return + if not self.epoch: + return num_cached = 0 cached_bytes = 0 diff --git a/returnn/datasets/cached2.py b/returnn/datasets/cached2.py index 34e596281b..6c25b64374 100644 --- a/returnn/datasets/cached2.py +++ b/returnn/datasets/cached2.py @@ -52,8 +52,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): Call this when you reset the seq list. """ super(CachedDataset2, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) - if not epoch: - epoch = 1 self.expected_load_seq_start = 0 self.reached_final_seq = False self.added_data = [] diff --git a/returnn/datasets/generating.py b/returnn/datasets/generating.py index e6bd63f2b0..09c3e27bfc 100644 --- a/returnn/datasets/generating.py +++ b/returnn/datasets/generating.py @@ -26,6 +26,7 @@ class GeneratingDataset(Dataset): _input_classes = None _output_classes = None + _getnewargs_remap = dict(num_seqs="_total_num_seqs", **Dataset._getnewargs_remap) def __init__(self, input_dim, output_dim, num_seqs=float("inf"), **kwargs): """ @@ -1664,10 +1665,10 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): """ assert seq_list is None and seq_order is None super(TimitDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) - self._num_seqs = len(self._seq_tags) self._seq_order = self.get_seq_order_for_epoch( - epoch=epoch, num_seqs=self._num_seqs, get_seq_len=lambda i: len(self._audio_data[self._seq_tags[i]][0]) + epoch=epoch, num_seqs=len(self._seq_tags), get_seq_len=lambda i: len(self._audio_data[self._seq_tags[i]][0]) ) + self._num_seqs = len(self._seq_order) self._random.seed(self._get_random_seed_for_epoch(epoch=epoch)) return True @@ -2081,8 +2082,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): import returnn.util.basic super(LibriSpeechCorpus, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) - if not epoch: - epoch = 1 random_seed = self._get_random_seed_for_epoch(epoch=epoch) self._audio_random.seed(random_seed) if self.targets: @@ -2107,7 +2106,7 @@ def get_seq_len(i): num_seqs = len(self._reference_seq_order) self._seq_order = self.get_seq_order_for_epoch(epoch=epoch, num_seqs=num_seqs, get_seq_len=get_seq_len) self._num_seqs = len(self._seq_order) - if self.epoch_wise_filter: + if self.epoch_wise_filter and epoch is not None: # Note: A more generic variant of this code is :class:`MetaDataset.EpochWiseFilter`. from .meta import EpochWiseFilter @@ -2356,10 +2355,8 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): :rtype: bool """ super(Enwik8Corpus, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) - if not epoch: - epoch = 1 epoch_part = None - if self.partition_epoch: + if self.partition_epoch and epoch is not None: epoch_part = (epoch - 1) % self.partition_epoch epoch = ((epoch - 1) // self.partition_epoch) + 1 self._random.seed(self._get_random_seed_for_epoch(epoch=epoch)) @@ -2380,7 +2377,7 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): seq_index = seq_index.transpose() seq_index = seq_index.flatten() self._seq_order = seq_index - if self.partition_epoch: + if self.partition_epoch and epoch is not None: assert self._num_seqs >= self.partition_epoch partition_epoch_num_seqs = [self._num_seqs // self.partition_epoch] * self.partition_epoch i = 0 diff --git a/returnn/datasets/hdf.py b/returnn/datasets/hdf.py index f8befcfcd2..551ce0143f 100644 --- a/returnn/datasets/hdf.py +++ b/returnn/datasets/hdf.py @@ -726,7 +726,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): elif seq_list is not None: self.seq_order = [self.seq_name_to_idx[s] for s in seq_list] else: - epoch = epoch or 1 self.seq_order = self.get_seq_order_for_epoch(epoch, len(self.all_seq_names), self._get_seq_length) def supports_seq_order_sorting(self) -> bool: diff --git a/returnn/datasets/lm.py b/returnn/datasets/lm.py index 5038c843d7..ba9a611e1b 100644 --- a/returnn/datasets/lm.py +++ b/returnn/datasets/lm.py @@ -282,8 +282,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): ) self.error_on_invalid_seq = True super(LmDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) - if not epoch: - epoch = 1 if seq_order is not None: self.seq_order = seq_order @@ -298,7 +296,7 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): self.num_skipped = 0 self.num_unknown = 0 if self.seq_gen: - self.seq_gen.random_seed(epoch) + self.seq_gen.random_seed(self._get_random_seed_for_epoch(epoch)) return True def supports_seq_order_sorting(self) -> bool: @@ -1458,8 +1456,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): :returns whether the order changed (True is always safe to return) """ super(TranslationDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) - if not epoch: - epoch = 1 if seq_list is None and self.seq_list: seq_list = self.seq_list diff --git a/returnn/datasets/meta.py b/returnn/datasets/meta.py index b5408b5184..5da60c8535 100644 --- a/returnn/datasets/meta.py +++ b/returnn/datasets/meta.py @@ -93,7 +93,8 @@ def filter(self, epoch, seq_order, get_seq_len): :param ((int)->int) get_seq_len: seq idx -> len :return: new seq_order """ - epoch = epoch or 1 + if epoch is None: + return seq_order old_num_seqs = len(seq_order) any_filter = False for (ep_start, ep_end), value in sorted(self.epochs_opts.items()): @@ -388,7 +389,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): super(MetaDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) if epoch is None: # This is called via initialize() with epoch=None, just to init some other things. - self.epoch = None # make sure we properly reinit # We are not expected to have prepared any real epoch here. self._num_seqs = 0 return True diff --git a/tests/test_TranslationDataset.py b/tests/test_TranslationDataset.py index 59bba6fc0d..ee4ab88aa3 100644 --- a/tests/test_TranslationDataset.py +++ b/tests/test_TranslationDataset.py @@ -78,7 +78,8 @@ def test_translation_dataset(): for postfix in ["", " "]: # test with and without postfix - # Replace one word by . This way it will not appear in the vocabulary (and is added to the vocabulary). + # Replace one word by . + # This way it will not appear in the vocabulary (and is added to the vocabulary). # We will test below whether this word is assigned the unknown id by checking whether the reconstruction also # contains . Note, that the input file is already written and contains the original word. dummy_target_text_with_unk = dummy_target_text.replace("TranslationDatasets", "") @@ -103,7 +104,7 @@ def test_translation_dataset(): target_postfix=postfix, unknown_label={"classes": ""}, ) - translation_dataset.init_seq_order() + translation_dataset.init_seq_order(epoch=1) translation_dataset.load_seqs(0, 10) num_seqs = len(dummy_source_text.splitlines()) @@ -184,7 +185,7 @@ def test_translation_factors_dataset(): target_postfix=postfix, ) - translation_dataset.init_seq_order() + translation_dataset.init_seq_order(epoch=1) translation_dataset.load_seqs(0, 10) num_seqs = len(dummy_target_text_factored_format.splitlines()) diff --git a/tests/test_torch_engine.py b/tests/test_torch_engine.py index 034caf8546..1e1dbe1b7f 100644 --- a/tests/test_torch_engine.py +++ b/tests/test_torch_engine.py @@ -34,6 +34,7 @@ def _forward_step(*, extern_data: TensorDict, **_kwargs): ) ) dataset = init_dataset({"class": "Task12AXDataset", "num_seqs": 100, "name": "dev", "fixed_random_seed": 1}) + dataset.init_seq_order(epoch=1) callback = ForwardCallbackIface() with global_config_ctx(config): @@ -81,6 +82,7 @@ def finish(self): ) ) dataset = init_dataset({"class": "Task12AXDataset", "num_seqs": 100, "name": "dev", "fixed_random_seed": 1}) + dataset.init_seq_order(epoch=1) callback = _ForwardCallback() with global_config_ctx(config): @@ -110,6 +112,7 @@ def _forward_step(*, extern_data: TensorDict, **_kwargs): ) ) dataset = init_dataset({"class": "Task12AXDataset", "num_seqs": 100, "name": "dev", "fixed_random_seed": 1}) + dataset.init_seq_order(epoch=1) callback = ForwardCallbackIface() with global_config_ctx(config): @@ -169,6 +172,7 @@ def process_seq(self, *, seq_tag: str, outputs: TensorDict): assert classes_ == _demo_txt + "." with global_config_ctx(config), create_ogg_zip_txt_only_dataset(text=_demo_txt, seq_tag=_demo_seq_tag) as dataset: + dataset.init_seq_order(epoch=1) engine = Engine(config=config) engine.init_network_from_config() engine.forward_with_callback(callback=_ForwardCallback(), dataset=dataset) @@ -226,6 +230,7 @@ def process_seq(self, *, seq_tag: str, outputs: TensorDict): callback = _ForwardCallback() with global_config_ctx(config): + dataset.init_seq_order(epoch=1) engine = Engine(config=config) engine.init_network_from_config() engine.forward_with_callback(callback=callback, dataset=dataset) @@ -238,6 +243,8 @@ def test_min_seq_len(): config = Config({"min_seq_length": 2, "batch_size": 3}) dataset = DummyDataset(input_dim=1, output_dim=4, num_seqs=1, seq_len=1) + dataset.initialize() + dataset.init_seq_order(epoch=1) engine = Engine(config=config) data_loader = engine._create_data_loader(dataset) for _ in data_loader: @@ -245,6 +252,8 @@ def test_min_seq_len(): config = Config(dict(batch_size=3)) dataset = DummyDataset(input_dim=1, output_dim=4, num_seqs=1, seq_len=3) + dataset.initialize() + dataset.init_seq_order(epoch=1) engine = Engine(config=config) data_loader = engine._create_data_loader(dataset) for _ in data_loader: @@ -258,6 +267,8 @@ def test_max_seq_len(): config = Config({"max_seq_length": 4, "batch_size": 3}) dataset = DummyDataset(input_dim=1, output_dim=4, num_seqs=1, seq_len=5) + dataset.initialize() + dataset.init_seq_order(epoch=1) engine = Engine(config=config) data_loader = engine._create_data_loader(dataset) for _ in data_loader: @@ -265,6 +276,8 @@ def test_max_seq_len(): config = Config(dict(batch_size=3)) dataset = DummyDataset(input_dim=1, output_dim=4, num_seqs=1, seq_len=3) + dataset.initialize() + dataset.init_seq_order(epoch=1) engine = Engine(config=config) data_loader = engine._create_data_loader(dataset) for _ in data_loader: