Skip to content

Commit

Permalink
PPDataset: be strict about seq_order and seq_list in `init_seq_…
Browse files Browse the repository at this point in the history
…order` (#1652)
  • Loading branch information
NeoLegends authored Nov 27, 2024
1 parent 88f3998 commit 5780574
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions returnn/datasets/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def __init__(
self._map_seq_stream = map_seq_stream
self._map_outputs = map_outputs
self._rng = RandomState(self._get_random_seed_for_epoch(0))
self._seq_list_for_validation: Optional[List[str]] = None

self._dataset = init_dataset(self._dataset_def, parent_dataset=self)
if self._map_seq_stream is None:
Expand Down Expand Up @@ -169,6 +170,12 @@ def init_seq_order(
"""
super().init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)

if self._map_seq_stream is not None:
if seq_list is not None:
raise ValueError("map_seq_stream is set, cannot specify custom seq_list")
if seq_order is not None:
raise ValueError("map_seq_stream is set, cannot specify custom seq_order")

if epoch is None and seq_list is None and seq_order is None:
self._num_seqs = 0
return True
Expand All @@ -177,6 +184,7 @@ def init_seq_order(
assert self._dataset is not None
self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
self._data_iter = enumerate(self._build_mapping_iter())
self._seq_list_for_validation = seq_list
if self._map_seq_stream is None:
# If we don't have an iterable mapper we know the number of segments exactly
# equals the number of segments in the wrapped dataset
Expand Down Expand Up @@ -260,7 +268,8 @@ def _iterate_dataset(self) -> Iterator[TensorDict]:
tensor_dict = self._in_tensor_dict_template.copy_template()
for data_key in data_keys:
tensor_dict.data[data_key].raw_tensor = self._dataset.get_data(seq_index, data_key)
tensor_dict.data["seq_tag"].raw_tensor = str_to_numpy_array(self._dataset.get_tag(seq_index))
seq_tag_tensor = str_to_numpy_array(self._dataset.get_tag(seq_index))
tensor_dict.data["seq_tag"].raw_tensor = seq_tag_tensor

if self._map_seq is not None:
tensor_dict = self._map_seq(
Expand All @@ -273,7 +282,14 @@ def _iterate_dataset(self) -> Iterator[TensorDict]:
# Re-adding the seq tag here causes no harm in case it's dropped since we don't
# add/drop any segments w/ the non-iterator postprocessing function.
if "seq_tag" not in tensor_dict.data:
tensor_dict.data["seq_tag"].raw_tensor = str_to_numpy_array(self._dataset.get_tag(seq_index))
tensor_dict.data["seq_tag"].raw_tensor = seq_tag_tensor

if self._seq_list_for_validation is not None:
seq_tag = self._seq_list_for_validation[seq_index]
tag_of_seq = tensor_dict.data["seq_tag"].raw_tensor.item()
assert (
tag_of_seq == seq_tag
), f"seq tag mismath: {tag_of_seq} != {seq_tag} for seq index {seq_index} when seq list is given"

yield tensor_dict
seq_index += 1
Expand Down

0 comments on commit 5780574

Please sign in to comment.