Skip to content

Commit

Permalink
PostprocessingDataset, provide epoch (and seq_idx in map_seq)
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Nov 26, 2024
1 parent 13640bc commit 88f3998
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions returnn/datasets/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def _validate_tensor_dict_iter(inner: Iterator[TensorDict]) -> Iterator[TensorDi

data_iter = self._iterate_dataset()
if self._map_seq_stream is not None:
data_iter = self._map_seq_stream(data_iter, rng=self._rng, **util.get_fwd_compat_kwargs())
data_iter = self._map_seq_stream(data_iter, epoch=self.epoch, rng=self._rng, **util.get_fwd_compat_kwargs())
assert isinstance(
data_iter, Iterator
), f"map_seq_stream must produce an {Iterator.__name__}, but produced {type(data_iter).__name__}"
Expand All @@ -263,7 +263,9 @@ def _iterate_dataset(self) -> Iterator[TensorDict]:
tensor_dict.data["seq_tag"].raw_tensor = str_to_numpy_array(self._dataset.get_tag(seq_index))

if self._map_seq is not None:
tensor_dict = self._map_seq(tensor_dict, rng=self._rng, **util.get_fwd_compat_kwargs())
tensor_dict = self._map_seq(
tensor_dict, epoch=self.epoch, seq_idx=seq_index, rng=self._rng, **util.get_fwd_compat_kwargs()
)
assert isinstance(
tensor_dict, TensorDict
), f"map_seq must produce a {TensorDict.__name__}, but produced {type(tensor_dict).__name__}"
Expand Down

0 comments on commit 88f3998

Please sign in to comment.