From 5780574cc1b1ebad321c02439d9d729c48535e3f Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Wed, 27 Nov 2024 05:38:43 -0800 Subject: [PATCH] `PPDataset`: be strict about `seq_order` and `seq_list` in `init_seq_order` (#1652) --- returnn/datasets/postprocessing.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 357e4e0d9..175a1e90c 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -127,6 +127,7 @@ def __init__( self._map_seq_stream = map_seq_stream self._map_outputs = map_outputs self._rng = RandomState(self._get_random_seed_for_epoch(0)) + self._seq_list_for_validation: Optional[List[str]] = None self._dataset = init_dataset(self._dataset_def, parent_dataset=self) if self._map_seq_stream is None: @@ -169,6 +170,12 @@ def init_seq_order( """ super().init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) + if self._map_seq_stream is not None: + if seq_list is not None: + raise ValueError("map_seq_stream is set, cannot specify custom seq_list") + if seq_order is not None: + raise ValueError("map_seq_stream is set, cannot specify custom seq_order") + if epoch is None and seq_list is None and seq_order is None: self._num_seqs = 0 return True @@ -177,6 +184,7 @@ def init_seq_order( assert self._dataset is not None self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) self._data_iter = enumerate(self._build_mapping_iter()) + self._seq_list_for_validation = seq_list if self._map_seq_stream is None: # If we don't have an iterable mapper we know the number of segments exactly # equals the number of segments in the wrapped dataset @@ -260,7 +268,8 @@ def _iterate_dataset(self) -> Iterator[TensorDict]: tensor_dict = self._in_tensor_dict_template.copy_template() for data_key in data_keys: tensor_dict.data[data_key].raw_tensor = self._dataset.get_data(seq_index, data_key) - tensor_dict.data["seq_tag"].raw_tensor = str_to_numpy_array(self._dataset.get_tag(seq_index)) + seq_tag_tensor = str_to_numpy_array(self._dataset.get_tag(seq_index)) + tensor_dict.data["seq_tag"].raw_tensor = seq_tag_tensor if self._map_seq is not None: tensor_dict = self._map_seq( @@ -273,7 +282,14 @@ def _iterate_dataset(self) -> Iterator[TensorDict]: # Re-adding the seq tag here causes no harm in case it's dropped since we don't # add/drop any segments w/ the non-iterator postprocessing function. if "seq_tag" not in tensor_dict.data: - tensor_dict.data["seq_tag"].raw_tensor = str_to_numpy_array(self._dataset.get_tag(seq_index)) + tensor_dict.data["seq_tag"].raw_tensor = seq_tag_tensor + + if self._seq_list_for_validation is not None: + seq_tag = self._seq_list_for_validation[seq_index] + tag_of_seq = tensor_dict.data["seq_tag"].raw_tensor.item() + assert ( + tag_of_seq == seq_tag + ), f"seq tag mismath: {tag_of_seq} != {seq_tag} for seq index {seq_index} when seq list is given" yield tensor_dict seq_index += 1