Skip to content

Commit

Permalink
SimpleHDFWriter: add extra_labels (#1653)
Browse files: browse the repository at this point in the history
  • Loading branch information
dorian-K authored Nov 26, 2024
1 parent 920f859 commit 13640bc
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions returnn/datasets/hdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,7 +1075,17 @@ class SimpleHDFWriter:
Note that we dump to a temp file first, and only at :func:`close` we move it over to the real destination.
"""

def __init__(self, filename, dim, labels=None, ndim=None, extra_type=None, swmr=False, extend_existing_file=False):
def __init__(
self,
filename,
dim,
labels=None,
ndim=None,
extra_type=None,
swmr=False,
extend_existing_file=False,
extra_labels=None,
):
"""
:param str filename: Create file, truncate if exists
:param int|None dim:
Expand All @@ -1084,6 +1094,7 @@ def __init__(self, filename, dim, labels=None, ndim=None, extra_type=None, swmr=
:param dict[str,(int,int,str)]|None extra_type: key -> (dim,ndim,dtype)
:param bool swmr: see https://docs.h5py.org/en/stable/swmr.html
:param bool extend_existing_file: True also means we expect that it exists
:param dict[str,list[str]]|None extra_labels: key -> labels
"""
from returnn.util.basic import hdf5_strings, unicode
import tempfile
Expand Down Expand Up @@ -1147,7 +1158,7 @@ def __init__(self, filename, dim, labels=None, ndim=None, extra_type=None, swmr=
self._extra_num_time_steps = {} # type: typing.Dict[str,int] # key -> num-steps
self._prepared_extra = set()
if extra_type:
self._prepare_extra(extra_type)
self._prepare_extra(extra_type, extra_labels if extra_labels else {})

if swmr:
assert not self._file.swmr_mode # this also checks whether the attribute exists (right version)
Expand All @@ -1160,7 +1171,7 @@ def __del__(self):
self._file.close()
self._file = None

def _prepare_extra(self, extra_type):
def _prepare_extra(self, extra_type, extra_labels):
"""
:param dict[str,(int,int,str)] extra_type: key -> (dim,ndim,dtype)
:return: whether we added a new entry
Expand All @@ -1178,7 +1189,11 @@ def _prepare_extra(self, extra_type):
self._file.create_group("targets/data")
self._file.create_group("targets/size")
self._file.create_group("targets/labels")
hdf5_strings(self._file, "targets/labels/%s" % data_key, ["dummy-label"])
labels = ["dummy-label"]
if data_key in extra_labels:
labels = extra_labels[data_key]
assert len(labels) == dim
hdf5_strings(self._file, "targets/labels/%s" % data_key, labels)
if ndim == 0:
ndim = 1 # we will automatically add a dummy-dim
shape = [None] * ndim # type: typing.List[typing.Optional[int]]
Expand Down Expand Up @@ -1261,7 +1276,7 @@ def _insert_h5_other(self, data_key, raw_data, dtype=None, add_time_dim=False, d
dtype = "string"
else:
dtype = raw_data.dtype.name
if self._prepare_extra({data_key: (dim, raw_data.ndim, dtype)}):
if self._prepare_extra({data_key: (dim, raw_data.ndim, dtype)}, {}):
# We added it now. Maybe other extra data keys were added before. The data_key_idx is different now.
# Thus, seq_lengths might have become invalid. Reinit them.
assert seq_idx == 0 or self.extend_existing_file # We can only do that in the beginning.
Expand Down

0 comments on commit 13640bc

Please sign in to comment.