From e215b19d5dd80ad2ac118a2032e80b09530fc5f3 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Sat, 2 Sep 2023 21:47:04 +0200 Subject: [PATCH] MetaDataset, fix get_total_num_seqs, have_seqs --- returnn/datasets/meta.py | 6 ++++++ tests/test_Dataset.py | 45 +++++++++++++++++++++++++++++++++------- 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/returnn/datasets/meta.py b/returnn/datasets/meta.py index 3444b8139b..b5408b5184 100644 --- a/returnn/datasets/meta.py +++ b/returnn/datasets/meta.py @@ -455,6 +455,12 @@ def get_all_tags(self): """ return self.seq_list_original[self.default_dataset_key] + def get_total_num_seqs(self) -> int: + """ + :return: total number of seqs, without partition epoch + """ + return self.num_total_seqs + def finish_epoch(self): """ This would get called at the end of the epoch. diff --git a/tests/test_Dataset.py b/tests/test_Dataset.py index cfd71b1936..577ca1fe98 100644 --- a/tests/test_Dataset.py +++ b/tests/test_Dataset.py @@ -483,9 +483,8 @@ def get_seq_len(i): @contextlib.contextmanager -def create_ogg_zip_txt_only_dataset(*, text: str = "hello world", seq_tag: str = "sequence0.wav"): +def create_ogg_zip_txt_only_dataset_opts(*, text: str = "hello world", seq_tag: str = "sequence0.wav"): import zipfile - from returnn.datasets.audio import OggZipDataset with tempfile.NamedTemporaryFile(suffix=".zip") as tmp_zip_file, tempfile.NamedTemporaryFile( suffix=".txt" @@ -500,12 +499,21 @@ def create_ogg_zip_txt_only_dataset(*, text: str = "hello world", seq_tag: str = tmp_vocab_file.write(repr(vocab).encode("utf8")) tmp_vocab_file.flush() - dataset = OggZipDataset( - path=tmp_zip_file.name, - audio=None, - targets={"class": "CharacterTargets", "vocab_file": tmp_vocab_file.name, "seq_postfix": [0]}, - ) - dataset.initialize() + yield { + "class": "OggZipDataset", + "path": tmp_zip_file.name, + "audio": None, + "targets": {"class": "CharacterTargets", "vocab_file": tmp_vocab_file.name, "seq_postfix": [0]}, + } + + +@contextlib.contextmanager +def create_ogg_zip_txt_only_dataset(*, text: str = "hello world", seq_tag: str = "sequence0.wav"): + from returnn.datasets.audio import OggZipDataset + + with create_ogg_zip_txt_only_dataset_opts(text=text, seq_tag=seq_tag) as opts: + dataset = init_dataset(opts) + assert isinstance(dataset, OggZipDataset) yield dataset @@ -516,6 +524,7 @@ def test_OggZipDataset(): with create_ogg_zip_txt_only_dataset(text=_demo_txt) as dataset: assert isinstance(dataset, OggZipDataset) + assert dataset.have_seqs() dataset.init_seq_order(epoch=1) dataset.load_seqs(0, 1) raw = dataset.get_data(0, "raw") @@ -535,6 +544,26 @@ def test_OggZipDataset(): assert classes_ == _demo_txt + "." +def test_MetaDataset(): + _demo_txt = "some utterance text" + + with create_ogg_zip_txt_only_dataset_opts(text=_demo_txt) as sub_ds_opts: + meta_ds_opts = { + "class": "MetaDataset", + "datasets": {"sub": sub_ds_opts}, + "data_map": {"classes": ("sub", "classes")}, + "seq_order_control_dataset": "sub", + } + dataset = init_dataset(meta_ds_opts) + assert dataset.have_seqs() + dataset.init_seq_order(epoch=1) + dataset.load_seqs(0, 1) + classes = dataset.get_data(0, "classes") + print("classes:", classes) + assert isinstance(classes, numpy.ndarray) and classes.dtype == numpy.int32 and classes.ndim == 1 + assert len(classes) == len(_demo_txt) + 1 + + def test_MapDatasetWrapper(): from returnn.datasets.map import MapDatasetBase