Datasets, handle epoch None (init) better, keep it #1387

Merged (6 commits, Sep 2, 2023)
Changes from all commits
2 changes: 2 additions & 0 deletions returnn/__main__.py
@@ -469,6 +469,7 @@ def execute_main_task():
assert data, "set forward_data"
else:
data = init_dataset(config.opt_typed_value("forward_data"))
data.init_seq_order(epoch=engine.epoch or 1)
forward_callback = config.typed_value("forward_callback")
assert forward_callback, "no forward_callback specified"
if callable(forward_callback):
@@ -482,6 +483,7 @@
if config.has("epoch"):
config.set("load_epoch", config.int("epoch", 0))
engine.init_network_from_config(config)
eval_data.init_seq_order(epoch=engine.epoch or 1)
output_file = config.value("output_file", "dump-fwd-epoch-%i.hdf" % engine.epoch)
forward_batch_size = config.int("forward_batch_size", 0)
if not forward_batch_size:
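Note on the two hunks above: a dataset constructed via init_dataset() no longer gets an implicit epoch, so the forward/eval tasks must pick one explicitly; `engine.epoch or 1` falls back to epoch 1 when no checkpoint epoch is known. A minimal runnable sketch of that contract, using a hypothetical stand-in dataset (not RETURNN code):

class _ToyDataset:
    def __init__(self):
        self.epoch = None  # after plain construction, no epoch is set
        self.seq_order = []

    def init_seq_order(self, epoch=None):
        # With this PR, epoch=None means "not initialized for a real epoch yet".
        self.epoch = epoch
        self.seq_order = [] if epoch is None else [0, 1, 2]

engine_epoch = None  # e.g. no epoch loaded from a checkpoint
data = _ToyDataset()
data.init_seq_order(epoch=engine_epoch or 1)  # same fallback as in the diff
assert data.seq_order == [0, 1, 2]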
11 changes: 9 additions & 2 deletions returnn/datasets/basic.py
@@ -198,6 +198,7 @@ def __repr__(self):
)

_getnewargs_exclude_attrs = set() # type: typing.Set[str]
_getnewargs_remap = {} # type: typing.Dict[str,str]

@staticmethod
def _create_from_reduce(cls, kwargs, state) -> Dataset:
@@ -223,7 +224,9 @@ def __reduce__(self):
for arg in inspect.getargs(cls.__init__.__code__).args[1:]:
if arg in self._getnewargs_exclude_attrs:
continue
if hasattr(self, "_" + arg):
if arg in self._getnewargs_remap:
kwargs[arg] = getattr(self, self._getnewargs_remap[arg])
elif hasattr(self, "_" + arg):
kwargs[arg] = getattr(self, "_" + arg)
else:
kwargs[arg] = getattr(self, arg)
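The new `_getnewargs_remap` hook lets a subclass point a constructor argument at a differently named attribute when the dataset is pickled; `GeneratingDataset` uses it below to map `num_seqs` to `_total_num_seqs`, since `_num_seqs` now reflects only the currently initialized epoch. A simplified, runnable sketch of the lookup (not the full `__reduce__` logic):

import inspect

class _Base:
    _getnewargs_exclude_attrs = set()
    _getnewargs_remap = {}

    def _collect_newargs(self):
        kwargs = {}
        for arg in inspect.getargs(self.__class__.__init__.__code__).args[1:]:
            if arg in self._getnewargs_exclude_attrs:
                continue
            if arg in self._getnewargs_remap:
                # e.g. num_seqs -> _total_num_seqs in GeneratingDataset below
                kwargs[arg] = getattr(self, self._getnewargs_remap[arg])
            elif hasattr(self, "_" + arg):
                kwargs[arg] = getattr(self, "_" + arg)
            else:
                kwargs[arg] = getattr(self, arg)
        return kwargs

class _Toy(_Base):
    _getnewargs_remap = dict(num_seqs="_total_num_seqs")

    def __init__(self, num_seqs=10):
        self._total_num_seqs = num_seqs  # fixed total, safe to pickle
        self._num_seqs = 0  # per-epoch count, wrong to use for reconstruction

assert _Toy(num_seqs=7)._collect_newargs() == {"num_seqs": 7}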
@@ -447,12 +450,16 @@ def get_seq_order_for_epoch(self, epoch, num_seqs, get_seq_len=None):
This is mostly a static method, except that it depends on the configured type of ordering,
such as 'default' (= as-is), 'sorted' or 'random'. 'sorted' also uses the sequence length.

:param int epoch: for 'random', this determines the random seed
:param int|None epoch: for 'random', this determines the random seed
:param int num_seqs:
:param ((int) -> int)|None get_seq_len: function (originalSeqIdx: int) -> int
:return: the order for the given epoch. such that seq_idx -> underlying idx
:rtype: typing.Sequence[int]
"""
if epoch is None:
# This might be called in the beginning. Skip this and wait until we init the real relevant epoch.
# We are not expected to have prepared any real epoch here.
return []
partition_epoch = self.partition_epoch or 1
repeat_epoch = self.repeat_epoch or 1
assert num_seqs > 0
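With the hunk above, `get_seq_order_for_epoch(epoch=None, ...)` now short-circuits to an empty order instead of proceeding with an unprepared epoch. A toy reimplementation of just that contract (not the actual RETURNN ordering logic):

import random

def toy_seq_order_for_epoch(epoch, num_seqs):
    """Toy version of the epoch=None contract; not RETURNN's ordering code."""
    if epoch is None:
        # Called during plain initialization; no real epoch is prepared yet.
        return []
    order = list(range(num_seqs))
    random.Random(epoch).shuffle(order)  # the epoch acts as the random seed
    return order

assert toy_seq_order_for_epoch(None, 100) == []
assert toy_seq_order_for_epoch(1, 5) == toy_seq_order_for_epoch(1, 5)  # deterministic per epoch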
4 changes: 4 additions & 0 deletions returnn/datasets/cached.py
@@ -152,6 +152,8 @@ def batch_set_generator_cache_whole_epoch(self):
def _init_alloc_intervals(self):
if self.cache_byte_size_limit_at_start == 0:
return
if self.epoch is None:
return
assert self.num_seqs > 0
assert self.num_inputs > 0
assert self.window > 0
@@ -183,6 +185,8 @@ def _init_start_cache(self):
return
if not self.nbytes:
return
if not self.epoch:
return

num_cached = 0
cached_bytes = 0
2 changes: 0 additions & 2 deletions returnn/datasets/cached2.py
@@ -52,8 +52,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
Call this when you reset the seq list.
"""
super(CachedDataset2, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
if not epoch:
epoch = 1
self.expected_load_seq_start = 0
self.reached_final_seq = False
self.added_data = []
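Dropping the `if not epoch: epoch = 1` coercion means `CachedDataset2.init_seq_order(epoch=None)` now records the init state instead of pretending epoch 1 was requested. A minimal sketch of the observable difference (hypothetical toy class, not CachedDataset2 itself):

class _ToyCached:
    """Mimics the CachedDataset2 change: epoch=None is kept, not coerced to 1."""

    def __init__(self):
        self.epoch = None
        self.expected_load_seq_start = 0
        self.reached_final_seq = False
        self.added_data = []

    def init_seq_order(self, epoch=None):
        self.epoch = epoch  # previously: if not epoch: epoch = 1
        self.expected_load_seq_start = 0
        self.reached_final_seq = False
        self.added_data = []

ds = _ToyCached()
ds.init_seq_order(epoch=None)
assert ds.epoch is None  # the init state stays visible to callers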
15 changes: 6 additions & 9 deletions returnn/datasets/generating.py
@@ -26,6 +26,7 @@ class GeneratingDataset(Dataset):

_input_classes = None
_output_classes = None
_getnewargs_remap = dict(num_seqs="_total_num_seqs", **Dataset._getnewargs_remap)

def __init__(self, input_dim, output_dim, num_seqs=float("inf"), **kwargs):
"""
@@ -1664,10 +1665,10 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
"""
assert seq_list is None and seq_order is None
super(TimitDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
self._num_seqs = len(self._seq_tags)
self._seq_order = self.get_seq_order_for_epoch(
epoch=epoch, num_seqs=self._num_seqs, get_seq_len=lambda i: len(self._audio_data[self._seq_tags[i]][0])
epoch=epoch, num_seqs=len(self._seq_tags), get_seq_len=lambda i: len(self._audio_data[self._seq_tags[i]][0])
)
self._num_seqs = len(self._seq_order)
self._random.seed(self._get_random_seed_for_epoch(epoch=epoch))
return True

@@ -2081,8 +2082,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
import returnn.util.basic

super(LibriSpeechCorpus, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
if not epoch:
epoch = 1
random_seed = self._get_random_seed_for_epoch(epoch=epoch)
self._audio_random.seed(random_seed)
if self.targets:
@@ -2107,7 +2106,7 @@ def get_seq_len(i):
num_seqs = len(self._reference_seq_order)
self._seq_order = self.get_seq_order_for_epoch(epoch=epoch, num_seqs=num_seqs, get_seq_len=get_seq_len)
self._num_seqs = len(self._seq_order)
if self.epoch_wise_filter:
if self.epoch_wise_filter and epoch is not None:
# Note: A more generic variant of this code is :class:`MetaDataset.EpochWiseFilter`.
from .meta import EpochWiseFilter

@@ -2356,10 +2355,8 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
:rtype: bool
"""
super(Enwik8Corpus, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
if not epoch:
epoch = 1
epoch_part = None
if self.partition_epoch:
if self.partition_epoch and epoch is not None:
epoch_part = (epoch - 1) % self.partition_epoch
epoch = ((epoch - 1) // self.partition_epoch) + 1
self._random.seed(self._get_random_seed_for_epoch(epoch=epoch))
@@ -2380,7 +2377,7 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
seq_index = seq_index.transpose()
seq_index = seq_index.flatten()
self._seq_order = seq_index
if self.partition_epoch:
if self.partition_epoch and epoch is not None:
assert self._num_seqs >= self.partition_epoch
partition_epoch_num_seqs = [self._num_seqs // self.partition_epoch] * self.partition_epoch
i = 0
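The `Enwik8Corpus` hunks above guard the partition-epoch arithmetic, which is undefined for `epoch=None`. A runnable sketch of the guarded computation (toy function, numbers only for illustration):

def split_partition_epoch(epoch, partition_epoch):
    """Toy version of the Enwik8Corpus arithmetic above."""
    if partition_epoch and epoch is not None:
        epoch_part = (epoch - 1) % partition_epoch
        epoch = ((epoch - 1) // partition_epoch) + 1
        return epoch, epoch_part
    return epoch, None

assert split_partition_epoch(5, 4) == (2, 0)  # sub-epoch 5 of partition 4 -> full epoch 2, part 0
assert split_partition_epoch(None, 4) == (None, None)  # init call: nothing to split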
1 change: 0 additions & 1 deletion returnn/datasets/hdf.py
@@ -726,7 +726,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
elif seq_list is not None:
self.seq_order = [self.seq_name_to_idx[s] for s in seq_list]
else:
epoch = epoch or 1
self.seq_order = self.get_seq_order_for_epoch(epoch, len(self.all_seq_names), self._get_seq_length)

def supports_seq_order_sorting(self) -> bool:
6 changes: 1 addition & 5 deletions returnn/datasets/lm.py
@@ -282,8 +282,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
)
self.error_on_invalid_seq = True
super(LmDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
if not epoch:
epoch = 1

if seq_order is not None:
self.seq_order = seq_order
@@ -298,7 +296,7 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
self.num_skipped = 0
self.num_unknown = 0
if self.seq_gen:
self.seq_gen.random_seed(epoch)
self.seq_gen.random_seed(self._get_random_seed_for_epoch(epoch))
return True

def supports_seq_order_sorting(self) -> bool:
@@ -1458,8 +1456,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
:returns whether the order changed (True is always safe to return)
"""
super(TranslationDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
if not epoch:
epoch = 1

if seq_list is None and self.seq_list:
seq_list = self.seq_list
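In `LmDataset`, seeding the sequence generator directly with `epoch` would be undefined for `epoch=None`; routing it through `_get_random_seed_for_epoch` (the same helper used elsewhere in this diff) keeps the seed well defined. A toy version of the idea only; this is not the helper's actual implementation:

def toy_random_seed_for_epoch(epoch, fixed_random_seed=None):
    """Sketch of the idea only; not the real _get_random_seed_for_epoch."""
    if fixed_random_seed is not None:
        return fixed_random_seed  # reproducible regardless of epoch
    if epoch is None:
        return 1  # well-defined fallback for the init call
    return epoch

assert toy_random_seed_for_epoch(None) == 1  # a raw None seed would be non-deterministic
assert toy_random_seed_for_epoch(3) == 3
assert toy_random_seed_for_epoch(3, fixed_random_seed=42) == 42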
4 changes: 2 additions & 2 deletions returnn/datasets/meta.py
@@ -93,7 +93,8 @@ def filter(self, epoch, seq_order, get_seq_len):
:param ((int)->int) get_seq_len: seq idx -> len
:return: new seq_order
"""
epoch = epoch or 1
if epoch is None:
return seq_order
old_num_seqs = len(seq_order)
any_filter = False
for (ep_start, ep_end), value in sorted(self.epochs_opts.items()):
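`EpochWiseFilter.filter` previously coerced `epoch or 1`, which wrongly applied epoch-1 filtering during the init call; returning `seq_order` untouched for `epoch=None` is the safer behavior. A toy illustration of the pass-through (the real filter options work on epoch ranges and sequence lengths, this stand-in does not):

def toy_epoch_wise_filter(epoch, seq_order):
    """Pass-through demo only; not the real EpochWiseFilter logic."""
    if epoch is None:
        return seq_order  # init call: apply no filtering at all
    return [i for i in seq_order if i % 2 == 0]  # stand-in for a real filter

order = [0, 1, 2, 3]
assert toy_epoch_wise_filter(None, order) is order  # returned unchanged
assert toy_epoch_wise_filter(1, order) == [0, 2]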
@@ -388,7 +389,6 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
super(MetaDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
if epoch is None:
# This is called via initialize() with epoch=None, just to init some other things.
self.epoch = None # make sure we properly reinit
# We are not expected to have prepared any real epoch here.
self._num_seqs = 0
return True
7 changes: 4 additions & 3 deletions tests/test_TranslationDataset.py
@@ -78,7 +78,8 @@ def test_translation_dataset():

for postfix in ["", " </S>"]: # test with and without postfix

# Replace one word by <UNK>. This way it will not appear in the vocabulary (and <UNK> is added to the vocabulary).
# Replace one word by <UNK>.
# This way it will not appear in the vocabulary (and <UNK> is added to the vocabulary).
# We will test below whether this word is assigned the unknown id by checking whether the reconstruction also
# contains <UNK>. Note that the input file is already written and contains the original word.
dummy_target_text_with_unk = dummy_target_text.replace("TranslationDatasets", "<UNK>")
@@ -103,7 +104,7 @@
target_postfix=postfix,
unknown_label={"classes": "<UNK>"},
)
translation_dataset.init_seq_order()
translation_dataset.init_seq_order(epoch=1)
translation_dataset.load_seqs(0, 10)

num_seqs = len(dummy_source_text.splitlines())
@@ -184,7 +185,7 @@ def test_translation_factors_dataset():
target_postfix=postfix,
)

translation_dataset.init_seq_order()
translation_dataset.init_seq_order(epoch=1)
translation_dataset.load_seqs(0, 10)

num_seqs = len(dummy_target_text_factored_format.splitlines())
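Both tests previously relied on a bare `init_seq_order()` producing something iterable; with epoch None now kept, they must request epoch 1 explicitly before `load_seqs`. A runnable toy showing the failure mode the change avoids (not `TranslationDataset` itself):

class _ToySeqDataset:
    def __init__(self, n):
        self._n = n
        self._order = []

    def init_seq_order(self, epoch=None):
        self._order = [] if epoch is None else list(range(self._n))

    def load_seqs(self, start, end):
        assert self._order, "no epoch initialized, nothing to load"
        return self._order[start:end]

ds = _ToySeqDataset(10)
ds.init_seq_order()  # what the tests used to do; now yields an empty order
failed = False
try:
    ds.load_seqs(0, 10)
except AssertionError:
    failed = True
assert failed  # fails as expected without an epoch
ds.init_seq_order(epoch=1)  # the fix applied in these tests
assert ds.load_seqs(0, 10) == list(range(10))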
13 changes: 13 additions & 0 deletions tests/test_torch_engine.py
@@ -34,6 +34,7 @@ def _forward_step(*, extern_data: TensorDict, **_kwargs):
)
)
dataset = init_dataset({"class": "Task12AXDataset", "num_seqs": 100, "name": "dev", "fixed_random_seed": 1})
dataset.init_seq_order(epoch=1)
callback = ForwardCallbackIface()

with global_config_ctx(config):
@@ -81,6 +82,7 @@ def finish(self):
)
)
dataset = init_dataset({"class": "Task12AXDataset", "num_seqs": 100, "name": "dev", "fixed_random_seed": 1})
dataset.init_seq_order(epoch=1)
callback = _ForwardCallback()

with global_config_ctx(config):
@@ -110,6 +112,7 @@ def _forward_step(*, extern_data: TensorDict, **_kwargs):
)
)
dataset = init_dataset({"class": "Task12AXDataset", "num_seqs": 100, "name": "dev", "fixed_random_seed": 1})
dataset.init_seq_order(epoch=1)
callback = ForwardCallbackIface()

with global_config_ctx(config):
@@ -169,6 +172,7 @@ def process_seq(self, *, seq_tag: str, outputs: TensorDict):
assert classes_ == _demo_txt + "."

with global_config_ctx(config), create_ogg_zip_txt_only_dataset(text=_demo_txt, seq_tag=_demo_seq_tag) as dataset:
dataset.init_seq_order(epoch=1)
engine = Engine(config=config)
engine.init_network_from_config()
engine.forward_with_callback(callback=_ForwardCallback(), dataset=dataset)
@@ -226,6 +230,7 @@ def process_seq(self, *, seq_tag: str, outputs: TensorDict):
callback = _ForwardCallback()

with global_config_ctx(config):
dataset.init_seq_order(epoch=1)
engine = Engine(config=config)
engine.init_network_from_config()
engine.forward_with_callback(callback=callback, dataset=dataset)
@@ -238,13 +243,17 @@ def test_min_seq_len():

config = Config({"min_seq_length": 2, "batch_size": 3})
dataset = DummyDataset(input_dim=1, output_dim=4, num_seqs=1, seq_len=1)
dataset.initialize()
dataset.init_seq_order(epoch=1)
engine = Engine(config=config)
data_loader = engine._create_data_loader(dataset)
for _ in data_loader:
assert False, "Should not contain sequences"

config = Config(dict(batch_size=3))
dataset = DummyDataset(input_dim=1, output_dim=4, num_seqs=1, seq_len=3)
dataset.initialize()
dataset.init_seq_order(epoch=1)
engine = Engine(config=config)
data_loader = engine._create_data_loader(dataset)
for _ in data_loader:
@@ -258,13 +267,17 @@ def test_max_seq_len():

config = Config({"max_seq_length": 4, "batch_size": 3})
dataset = DummyDataset(input_dim=1, output_dim=4, num_seqs=1, seq_len=5)
dataset.initialize()
dataset.init_seq_order(epoch=1)
engine = Engine(config=config)
data_loader = engine._create_data_loader(dataset)
for _ in data_loader:
assert False, "Should not contain sequences"

config = Config(dict(batch_size=3))
dataset = DummyDataset(input_dim=1, output_dim=4, num_seqs=1, seq_len=3)
dataset.initialize()
dataset.init_seq_order(epoch=1)
engine = Engine(config=config)
data_loader = engine._create_data_loader(dataset)
for _ in data_loader:
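The torch-engine tests now spell out the full two-step dataset setup: `initialize()` for one-time preparation, then `init_seq_order(epoch=1)` to select a concrete epoch. A minimal stand-in showing the contract these tests assume (hypothetical class, not a RETURNN dataset):

class _ToyTorchDataset:
    def __init__(self):
        self.epoch = None
        self._prepared = False

    def initialize(self):
        # one-time setup (vocab, file handles, ...); the epoch stays None here
        self._prepared = True

    def init_seq_order(self, epoch=None):
        assert self._prepared, "call initialize() first"
        self.epoch = epoch

ds = _ToyTorchDataset()
ds.initialize()  # step 1: one-time init, still no epoch
assert ds.epoch is None
ds.init_seq_order(epoch=1)  # step 2: pick the concrete epoch to iterate
assert ds.epoch == 1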