From e255edcbca8d6b53a06d8845eaa7a168fb71dc92 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Fri, 29 Nov 2024 15:32:42 +0100
Subject: [PATCH] LmDataset, support get_all_tags, check given seq_list

---
 returnn/datasets/lm.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/returnn/datasets/lm.py b/returnn/datasets/lm.py
index f4809f364..fdce9fb05 100644
--- a/returnn/datasets/lm.py
+++ b/returnn/datasets/lm.py
@@ -445,6 +445,7 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
         if seq_order is not None:
             self.seq_order = seq_order
         elif seq_list is not None:
+            assert all(s.startswith(self._tag_prefix) for s in seq_list)
             self.seq_order = [int(s[len(self._tag_prefix) :]) for s in seq_list]
         elif epoch is None:
             self.seq_order = []
@@ -479,6 +480,11 @@ def get_total_num_seqs(self, *, fast: bool = False) -> int:
         self._lazy_init()
         return len(self._orths_offsets_and_lens)
 
+    def get_all_tags(self) -> List[str]:
+        """:return: all seq tags"""
+        num_seqs = self.get_total_num_seqs()
+        return [self._tag_prefix + str(line_nr) for line_nr in range(num_seqs)]
+
     def _reduce_log_skipped_seqs(self):
         if isinstance(self.log_skipped_seqs, bool):
             return