diff --git a/returnn/datasets/sprint.py b/returnn/datasets/sprint.py index 475a094234..0d46f5279e 100644 --- a/returnn/datasets/sprint.py +++ b/returnn/datasets/sprint.py @@ -23,8 +23,9 @@ from returnn.datasets.basic import Dataset, DatasetSeq from .cached2 import CachedDataset2 from returnn.log import log -from returnn.util.task_system import Unpickler, numpy_copy_and_set_unused -from returnn.util.basic import eval_shell_str, interrupt_main, unicode, PY3, BytesIO, close_all_fds_except +from returnn.util.task_system import numpy_copy_and_set_unused +from returnn.util.basic import eval_shell_str, interrupt_main, unicode, PY3, close_all_fds_except +import returnn.util.basic as util class SprintDatasetBase(Dataset): @@ -913,31 +914,9 @@ def _read_next_raw(self): :return: (data_type, args) :rtype: (str, object) """ - import struct - - size_raw = self.pipe_c2p[0].read(4) - if len(size_raw) < 4: - raise EOFError - (size,) = struct.unpack(" 0, "%s: We expect to get some non-empty package. Invalid Python mod in Sprint?" % (self,) - stream = BytesIO() - read_size = 0 - while read_size < size: - data_raw = self.pipe_c2p[0].read(size - read_size) - if len(data_raw) == 0: - raise EOFError("%s: expected to read %i bytes but got EOF after %i bytes" % (self, size, read_size)) - read_size += len(data_raw) - stream.write(data_raw) - stream.seek(0) - try: - if PY3: - # encoding is for converting Python2 strings to Python3. - # Cannot use utf8 because Numpy will also encode the data as strings and there we need it as bytes. - data_type, args = Unpickler(stream, encoding="bytes").load() - else: - data_type, args = Unpickler(stream).load() - except EOFError: - raise Exception("%s: parse error of %i bytes (%r)" % (self, size, stream.getvalue())) + # encoding is for converting Python2 strings to Python3. + # Cannot use utf8 because Numpy will also encode the data as strings and there we need it as bytes. + data_type, args = util.read_pickled_object(self.pipe_c2p[0], encoding="bytes") return data_type, args def _join_child(self, wait=True, expected_exit_status=None): diff --git a/returnn/sprint/control.py b/returnn/sprint/control.py index a3bd447dd2..3a4f43779f 100644 --- a/returnn/sprint/control.py +++ b/returnn/sprint/control.py @@ -15,15 +15,14 @@ import os import numpy import typing -import io -import struct from threading import Condition import returnn.__main__ as rnn import returnn.util.task_system as task_system import returnn.util.debug as debug -from returnn.util.task_system import Pickler, Unpickler, numpy_set_unused +from returnn.util.task_system import numpy_set_unused from returnn.util.basic import to_bool, long +import returnn.util.basic as util InitTypes = set() Verbose = False # disables all per-segment log messages @@ -434,31 +433,10 @@ def _get_loss_and_error_signal_via_sprint_callback(self, seg_name, orthography, return loss, error_signal def _send(self, data): - stream = io.BytesIO() - Pickler(stream).dump(data) - raw_data = stream.getvalue() - assert len(raw_data) > 0 - self.pipe_c2p.write(struct.pack(" 0, "%s: We expect to get some non-empty package. Invalid Python mod in Sprint?" % (self,) - stream = io.BytesIO() - read_size = 0 - while read_size < size: - data_raw = p.read(size - read_size) - if len(data_raw) == 0: - raise EOFError("%s: expected to read %i bytes but got EOF after %i bytes" % (self, size, read_size)) - read_size += len(data_raw) - stream.write(data_raw) - stream.seek(0) - return Unpickler(stream).load() + return util.read_pickled_object(self.pipe_p2c) def close(self): """ diff --git a/returnn/sprint/error_signals.py b/returnn/sprint/error_signals.py index ab7d1cb040..bc4de91996 100644 --- a/returnn/sprint/error_signals.py +++ b/returnn/sprint/error_signals.py @@ -14,12 +14,11 @@ import atexit import signal import typing -import struct -import io from threading import RLock, Thread import returnn.util.task_system as task_system -from returnn.util.task_system import Pickler, Unpickler, numpy_set_unused +from returnn.util.task_system import numpy_set_unused from returnn.util.basic import eval_shell_str, make_hashable, close_all_fds_except +import returnn.util.basic as util from returnn.log import log @@ -219,32 +218,12 @@ def _build_sprint_args(self): def _send(self, v): assert os.getpid() == self.parent_pid p = self.pipe_p2c[1] # see _start_child - stream = io.BytesIO() - Pickler(stream).dump(v) - raw_data = stream.getvalue() - assert len(raw_data) > 0 - p.write(struct.pack(" 0, "%s: We expect to get some non-empty package. Invalid Python mod in Sprint?" % (self,) - stream = io.BytesIO() - read_size = 0 - while read_size < size: - data_raw = p.read(size - read_size) - if len(data_raw) == 0: - raise EOFError("%s: expected to read %i bytes but got EOF after %i bytes" % (self, size, read_size)) - read_size += len(data_raw) - stream.write(data_raw) - stream.seek(0) - return Unpickler(stream).load() + return util.read_pickled_object(p) def _poll(self): assert os.getpid() == self.parent_pid diff --git a/returnn/sprint/extern_interface.py b/returnn/sprint/extern_interface.py index 73395eceef..5b9971c1bd 100644 --- a/returnn/sprint/extern_interface.py +++ b/returnn/sprint/extern_interface.py @@ -11,8 +11,8 @@ import typing from returnn.util import better_exchook import returnn.util.task_system as task_system -from returnn.util.task_system import Pickler -from returnn.util.basic import to_bool, unicode, BytesIO +from returnn.util.basic import to_bool, unicode +import returnn.util.basic as util # Start Sprint PythonSegmentOrder interface. { # We use the PythonSegmentOrder just to get an estimate (upper limit) about the number of sequences. @@ -347,15 +347,7 @@ def _send(self, data_type, args=None): :param object args: """ assert data_type is not None - import struct - - stream = BytesIO() - Pickler(stream).dump((data_type, args)) - raw_data = stream.getvalue() - assert len(raw_data) > 0 - self.pipe_c2p.write(struct.pack(" BytesIO: + """ + Read bytes from stream s into a BytesIO buffer. + Raises EOFError if not enough bytes are available. + Then read it via :func:`read_pickled_object`. + """ + stream = BytesIO() + read_size = 0 + while read_size < size: + data_raw = p.read(size - read_size) + if len(data_raw) == 0: + raise EOFError("expected to read %i bytes but got EOF after %i bytes" % (size, read_size)) + read_size += len(data_raw) + stream.write(data_raw) + stream.seek(0) + return stream + + +def read_pickled_object(p: typing.BinaryIO, *, encoding=None) -> Any: + """ + Read pickled object from stream p, + after it was written via :func:`read_bytes_to_new_buffer`. + + :param p: + :param encoding: if given, passed to Unpickler + """ + from returnn.util.task_system import Unpickler + import struct + + size_raw = read_bytes_to_new_buffer(p, 4) + (size,) = struct.unpack(" 0, "%s: We expect to get some non-empty package. Invalid Python mod in Sprint?" % (self,) + stream = read_bytes_to_new_buffer(p, size) + unpickler_kwargs = {} + if encoding: + unpickler_kwargs["encoding"] = encoding + return Unpickler(stream, **unpickler_kwargs).load() + + +def write_pickled_object(p: typing.BinaryIO, obj: Any): + """ + Writes pickled object to stream p. + """ + stream = io.BytesIO() + Pickler(stream).dump(data) + raw_data = stream.getvalue() + assert len(raw_data) > 0 + p.write(struct.pack("