Skip to content

Commit

Permalink
LmDataset, support get_all_tags, check given seq_list
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Nov 29, 2024
1 parent 882e00c commit e255edc
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions returnn/datasets/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
if seq_order is not None:
self.seq_order = seq_order
elif seq_list is not None:
assert all(s.startswith(self._tag_prefix) for s in seq_list)
self.seq_order = [int(s[len(self._tag_prefix) :]) for s in seq_list]
elif epoch is None:
self.seq_order = []
Expand Down Expand Up @@ -479,6 +480,11 @@ def get_total_num_seqs(self, *, fast: bool = False) -> int:
self._lazy_init()
return len(self._orths_offsets_and_lens)

def get_all_tags(self) -> List[str]:
""":return: all seq tags"""
num_seqs = self.get_total_num_seqs()
return [self._tag_prefix + str(line_nr) for line_nr in range(num_seqs)]

def _reduce_log_skipped_seqs(self):
if isinstance(self.log_skipped_seqs, bool):
return
Expand Down

0 comments on commit e255edc

Please sign in to comment.