Skip to content

Commit

Permalink
MetaDataset, fix get_total_num_seqs, have_seqs
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Sep 2, 2023
1 parent 4433c9f commit e215b19
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 8 deletions.
6 changes: 6 additions & 0 deletions returnn/datasets/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,12 @@ def get_all_tags(self):
"""
return self.seq_list_original[self.default_dataset_key]

def get_total_num_seqs(self) -> int:
    """
    Report the overall sequence count of this dataset.

    :return: the value of ``self.num_total_seqs``,
        i.e. the total number of seqs, without partition epoch
    """
    total = self.num_total_seqs
    return total

def finish_epoch(self):
"""
This would get called at the end of the epoch.
Expand Down
45 changes: 37 additions & 8 deletions tests/test_Dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,9 +483,8 @@ def get_seq_len(i):


@contextlib.contextmanager
def create_ogg_zip_txt_only_dataset(*, text: str = "hello world", seq_tag: str = "sequence0.wav"):
def create_ogg_zip_txt_only_dataset_opts(*, text: str = "hello world", seq_tag: str = "sequence0.wav"):
import zipfile
from returnn.datasets.audio import OggZipDataset

with tempfile.NamedTemporaryFile(suffix=".zip") as tmp_zip_file, tempfile.NamedTemporaryFile(
suffix=".txt"
Expand All @@ -500,12 +499,21 @@ def create_ogg_zip_txt_only_dataset(*, text: str = "hello world", seq_tag: str =
tmp_vocab_file.write(repr(vocab).encode("utf8"))
tmp_vocab_file.flush()

dataset = OggZipDataset(
path=tmp_zip_file.name,
audio=None,
targets={"class": "CharacterTargets", "vocab_file": tmp_vocab_file.name, "seq_postfix": [0]},
)
dataset.initialize()
yield {
"class": "OggZipDataset",
"path": tmp_zip_file.name,
"audio": None,
"targets": {"class": "CharacterTargets", "vocab_file": tmp_vocab_file.name, "seq_postfix": [0]},
}


@contextlib.contextmanager
def create_ogg_zip_txt_only_dataset(*, text: str = "hello world", seq_tag: str = "sequence0.wav"):
    """
    Context manager which yields a dataset instance,
    created via ``init_dataset`` from the options provided by
    ``create_ogg_zip_txt_only_dataset_opts``.

    :param text: passed on to ``create_ogg_zip_txt_only_dataset_opts``
    :param seq_tag: passed on to ``create_ogg_zip_txt_only_dataset_opts``
    :return: context manager; the yielded dataset is asserted to be an ``OggZipDataset``
    """
    from returnn.datasets.audio import OggZipDataset

    opts_ctx = create_ogg_zip_txt_only_dataset_opts(text=text, seq_tag=seq_tag)
    # The dataset is only yielded while the options context is still open.
    with opts_ctx as ds_opts:
        ogg_zip_ds = init_dataset(ds_opts)
        assert isinstance(ogg_zip_ds, OggZipDataset)
        yield ogg_zip_ds


Expand All @@ -516,6 +524,7 @@ def test_OggZipDataset():

with create_ogg_zip_txt_only_dataset(text=_demo_txt) as dataset:
assert isinstance(dataset, OggZipDataset)
assert dataset.have_seqs()
dataset.init_seq_order(epoch=1)
dataset.load_seqs(0, 1)
raw = dataset.get_data(0, "raw")
Expand All @@ -535,6 +544,26 @@ def test_OggZipDataset():
assert classes_ == _demo_txt + "."


def test_MetaDataset():
    """
    Wrap a text-only OggZipDataset config inside a MetaDataset config,
    then check ``have_seqs`` and the "classes" data of the first seq.
    """
    _demo_txt = "some utterance text"

    with create_ogg_zip_txt_only_dataset_opts(text=_demo_txt) as sub_ds_opts:
        # Single sub dataset "sub", which also controls the seq order.
        sub_datasets = {"sub": sub_ds_opts}
        key_mapping = {"classes": ("sub", "classes")}
        meta_ds_opts = {
            "class": "MetaDataset",
            "datasets": sub_datasets,
            "data_map": key_mapping,
            "seq_order_control_dataset": "sub",
        }
        meta_ds = init_dataset(meta_ds_opts)
        assert meta_ds.have_seqs()
        meta_ds.init_seq_order(epoch=1)
        meta_ds.load_seqs(0, 1)
        seq_classes = meta_ds.get_data(0, "classes")
        print("classes:", seq_classes)
        assert isinstance(seq_classes, numpy.ndarray)
        assert seq_classes.dtype == numpy.int32
        assert seq_classes.ndim == 1
        # One more label than there are characters in the text.
        assert len(seq_classes) == len(_demo_txt) + 1


def test_MapDatasetWrapper():
from returnn.datasets.map import MapDatasetBase

Expand Down

0 comments on commit e215b19

Please sign in to comment.