diff --git a/returnn/datasets/hdf.py b/returnn/datasets/hdf.py index e0e3cc7a8..5a1e9383e 100644 --- a/returnn/datasets/hdf.py +++ b/returnn/datasets/hdf.py @@ -1075,7 +1075,17 @@ class SimpleHDFWriter: Note that we dump to a temp file first, and only at :func:`close` we move it over to the real destination. """ - def __init__(self, filename, dim, labels=None, ndim=None, extra_type=None, swmr=False, extend_existing_file=False): + def __init__( + self, + filename, + dim, + labels=None, + ndim=None, + extra_type=None, + swmr=False, + extend_existing_file=False, + extra_labels=None, + ): """ :param str filename: Create file, truncate if exists :param int|None dim: @@ -1084,6 +1094,7 @@ def __init__(self, filename, dim, labels=None, ndim=None, extra_type=None, swmr= :param dict[str,(int,int,str)]|None extra_type: key -> (dim,ndim,dtype) :param bool swmr: see https://docs.h5py.org/en/stable/swmr.html :param bool extend_existing_file: True also means we expect that it exists + :param dict[str,list[str]]|None extra_labels: key -> labels """ from returnn.util.basic import hdf5_strings, unicode import tempfile @@ -1147,7 +1158,7 @@ def __init__(self, filename, dim, labels=None, ndim=None, extra_type=None, swmr= self._extra_num_time_steps = {} # type: typing.Dict[str,int] # key -> num-steps self._prepared_extra = set() if extra_type: - self._prepare_extra(extra_type) + self._prepare_extra(extra_type, extra_labels if extra_labels else {}) if swmr: assert not self._file.swmr_mode # this also checks whether the attribute exists (right version) @@ -1160,7 +1171,7 @@ def __del__(self): self._file.close() self._file = None - def _prepare_extra(self, extra_type): + def _prepare_extra(self, extra_type, extra_labels): """ :param dict[str,(int,int,str)] extra_type: key -> (dim,ndim,dtype) :return: whether we added a new entry @@ -1178,7 +1189,11 @@ def _prepare_extra(self, extra_type): self._file.create_group("targets/data") self._file.create_group("targets/size") self._file.create_group("targets/labels") - hdf5_strings(self._file, "targets/labels/%s" % data_key, ["dummy-label"]) + labels = ["dummy-label"] + if data_key in extra_labels: + labels = extra_labels[data_key] + assert len(labels) == dim + hdf5_strings(self._file, "targets/labels/%s" % data_key, labels) if ndim == 0: ndim = 1 # we will automatically add a dummy-dim shape = [None] * ndim # type: typing.List[typing.Optional[int]] @@ -1261,7 +1276,7 @@ def _insert_h5_other(self, data_key, raw_data, dtype=None, add_time_dim=False, d dtype = "string" else: dtype = raw_data.dtype.name - if self._prepare_extra({data_key: (dim, raw_data.ndim, dtype)}): + if self._prepare_extra({data_key: (dim, raw_data.ndim, dtype)}, {}): # We added it now. Maybe other extra data keys were added before. The data_key_idx is different now. # Thus, seq_lengths might have become invalid. Reinit them. assert seq_idx == 0 or self.extend_existing_file # We can only do that in the beginning.