Skip to content

Commit

Permalink
LmDataset get_total_num_seqs (#1442)
Browse files Browse the repository at this point in the history
Should fix #1441.

Also:

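CachedDataset2.num_seqs now returns 0 if the dataset is not initialized (epoch is None) instead of raising NotImplementedError.
CachedDataset2.is_less_than_num_seqs handles the uninitialized case the same way (returns False when epoch is None).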

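Dataset.have_seqs is more correct: it falls back to is_less_than_num_seqs(0) only when an epoch is set, and otherwise raises NotImplementedError.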
  • Loading branch information
albertz authored Oct 21, 2023
1 parent a506073 commit e3c0303
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 1 deletion.
4 changes: 3 additions & 1 deletion returnn/datasets/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,9 @@ def have_seqs(self) -> bool:
return total_num_seqs > 0
except NotImplementedError:
pass
return self.is_less_than_num_seqs(0)
if self.epoch is not None:
return self.is_less_than_num_seqs(0)
raise NotImplementedError(f"{self} have_seqs() is not implemented (and neither get_total_num_seqs())")

def len_info(self):
"""
Expand Down
4 changes: 4 additions & 0 deletions returnn/datasets/cached2.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ def num_seqs(self):
"""
if self._num_seqs is not None:
return self._num_seqs
if self.epoch is None:
return 0
raise NotImplementedError

def _load_seqs(self, start, end):
Expand Down Expand Up @@ -128,6 +130,8 @@ def is_less_than_num_seqs(self, n):
"""
if n < self.expected_load_seq_start:
return True
if self.epoch is None:
return False
# noinspection PyBroadException
try:
return super(CachedDataset2, self).is_less_than_num_seqs(n)
Expand Down
4 changes: 4 additions & 0 deletions returnn/datasets/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,10 @@ def supports_seq_order_sorting(self) -> bool:
"""supports sorting"""
return True

def get_total_num_seqs(self) -> int:
    """
    :return: total number of sequences, i.e. the number of entries in ``self.orths``
    """
    orths = self.orths
    return len(orths)

def _reduce_log_skipped_seqs(self):
if isinstance(self.log_skipped_seqs, bool):
return
Expand Down

0 comments on commit e3c0303

Please sign in to comment.