Skip to content

Commit

Permalink
Cleanup Sprint serialization
Browse files Browse the repository at this point in the history
Small followup to #1456
  • Loading branch information
albertz committed Nov 8, 2023
1 parent a3d1094 commit 09adef0
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 90 deletions.
33 changes: 6 additions & 27 deletions returnn/datasets/sprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
from returnn.datasets.basic import Dataset, DatasetSeq
from .cached2 import CachedDataset2
from returnn.log import log
from returnn.util.task_system import Unpickler, numpy_copy_and_set_unused
from returnn.util.basic import eval_shell_str, interrupt_main, unicode, PY3, BytesIO, close_all_fds_except
from returnn.util.task_system import numpy_copy_and_set_unused
from returnn.util.basic import eval_shell_str, interrupt_main, unicode, PY3, close_all_fds_except
import returnn.util.basic as util


class SprintDatasetBase(Dataset):
Expand Down Expand Up @@ -913,31 +914,9 @@ def _read_next_raw(self):
:return: (data_type, args)
:rtype: (str, object)
"""
import struct

size_raw = self.pipe_c2p[0].read(4)
if len(size_raw) < 4:
raise EOFError
(size,) = struct.unpack("<i", size_raw)
assert size > 0, "%s: We expect to get some non-empty package. Invalid Python mod in Sprint?" % (self,)
stream = BytesIO()
read_size = 0
while read_size < size:
data_raw = self.pipe_c2p[0].read(size - read_size)
if len(data_raw) == 0:
raise EOFError("%s: expected to read %i bytes but got EOF after %i bytes" % (self, size, read_size))
read_size += len(data_raw)
stream.write(data_raw)
stream.seek(0)
try:
if PY3:
# encoding is for converting Python2 strings to Python3.
# Cannot use utf8 because Numpy will also encode the data as strings and there we need it as bytes.
data_type, args = Unpickler(stream, encoding="bytes").load()
else:
data_type, args = Unpickler(stream).load()
except EOFError:
raise Exception("%s: parse error of %i bytes (%r)" % (self, size, stream.getvalue()))
# encoding is for converting Python2 strings to Python3.
# Cannot use utf8 because Numpy will also encode the data as strings and there we need it as bytes.
data_type, args = util.read_pickled_object(self.pipe_c2p[0], encoding="bytes")
return data_type, args

def _join_child(self, wait=True, expected_exit_status=None):
Expand Down
30 changes: 4 additions & 26 deletions returnn/sprint/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@
import os
import numpy
import typing
import io
import struct
from threading import Condition

import returnn.__main__ as rnn
import returnn.util.task_system as task_system
import returnn.util.debug as debug
from returnn.util.task_system import Pickler, Unpickler, numpy_set_unused
from returnn.util.task_system import numpy_set_unused
from returnn.util.basic import to_bool, long
import returnn.util.basic as util

InitTypes = set()
Verbose = False # disables all per-segment log messages
Expand Down Expand Up @@ -434,31 +433,10 @@ def _get_loss_and_error_signal_via_sprint_callback(self, seg_name, orthography,
return loss, error_signal

def _send(self, data):
stream = io.BytesIO()
Pickler(stream).dump(data)
raw_data = stream.getvalue()
assert len(raw_data) > 0
self.pipe_c2p.write(struct.pack("<i", len(raw_data)))
self.pipe_c2p.write(raw_data)
self.pipe_c2p.flush()
util.write_pickled_object(self.pipe_c2p, data)

def _read(self):
p = self.pipe_p2c
size_raw = p.read(4)
if len(size_raw) < 4:
raise EOFError
(size,) = struct.unpack("<i", size_raw)
assert size > 0, "%s: We expect to get some non-empty package. Invalid Python mod in Sprint?" % (self,)
stream = io.BytesIO()
read_size = 0
while read_size < size:
data_raw = p.read(size - read_size)
if len(data_raw) == 0:
raise EOFError("%s: expected to read %i bytes but got EOF after %i bytes" % (self, size, read_size))
read_size += len(data_raw)
stream.write(data_raw)
stream.seek(0)
return Unpickler(stream).load()
return util.read_pickled_object(self.pipe_p2c)

def close(self):
"""
Expand Down
29 changes: 4 additions & 25 deletions returnn/sprint/error_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
import atexit
import signal
import typing
import struct
import io
from threading import RLock, Thread
import returnn.util.task_system as task_system
from returnn.util.task_system import Pickler, Unpickler, numpy_set_unused
from returnn.util.task_system import numpy_set_unused
from returnn.util.basic import eval_shell_str, make_hashable, close_all_fds_except
import returnn.util.basic as util
from returnn.log import log


Expand Down Expand Up @@ -219,32 +218,12 @@ def _build_sprint_args(self):
def _send(self, v):
assert os.getpid() == self.parent_pid
p = self.pipe_p2c[1] # see _start_child
stream = io.BytesIO()
Pickler(stream).dump(v)
raw_data = stream.getvalue()
assert len(raw_data) > 0
p.write(struct.pack("<i", len(raw_data)))
p.write(raw_data)
p.flush()
util.write_pickled_object(p, v)

def _read(self):
assert os.getpid() == self.parent_pid
p = self.pipe_c2p[0] # see _start_child
size_raw = p.read(4)
if len(size_raw) < 4:
raise EOFError
(size,) = struct.unpack("<i", size_raw)
assert size > 0, "%s: We expect to get some non-empty package. Invalid Python mod in Sprint?" % (self,)
stream = io.BytesIO()
read_size = 0
while read_size < size:
data_raw = p.read(size - read_size)
if len(data_raw) == 0:
raise EOFError("%s: expected to read %i bytes but got EOF after %i bytes" % (self, size, read_size))
read_size += len(data_raw)
stream.write(data_raw)
stream.seek(0)
return Unpickler(stream).load()
return util.read_pickled_object(p)

def _poll(self):
assert os.getpid() == self.parent_pid
Expand Down
14 changes: 3 additions & 11 deletions returnn/sprint/extern_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import typing
from returnn.util import better_exchook
import returnn.util.task_system as task_system
from returnn.util.task_system import Pickler
from returnn.util.basic import to_bool, unicode, BytesIO
from returnn.util.basic import to_bool, unicode
import returnn.util.basic as util

# Start Sprint PythonSegmentOrder interface. {
# We use the PythonSegmentOrder just to get an estimate (upper limit) about the number of sequences.
Expand Down Expand Up @@ -347,15 +347,7 @@ def _send(self, data_type, args=None):
:param object args:
"""
assert data_type is not None
import struct

stream = BytesIO()
Pickler(stream).dump((data_type, args))
raw_data = stream.getvalue()
assert len(raw_data) > 0
self.pipe_c2p.write(struct.pack("<i", len(raw_data)))
self.pipe_c2p.write(raw_data)
self.pipe_c2p.flush()
util.write_pickled_object(self.pipe_c2p, (data_type, args))

def add_new_data(self, segment_name, features, targets):
"""
Expand Down
54 changes: 53 additions & 1 deletion returnn/util/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

from __future__ import annotations
from typing import Optional, Generic, TypeVar, Iterable, Tuple, Dict, List, Callable
from typing import Optional, Any, Generic, TypeVar, Iterable, Tuple, Dict, List, Callable

import subprocess
from subprocess import CalledProcessError
Expand Down Expand Up @@ -2597,6 +2597,58 @@ def pickle_loads(s):
return c


def read_bytes_to_new_buffer(p: typing.BinaryIO, size: int) -> BytesIO:
"""
Read bytes from stream s into a BytesIO buffer.
Raises EOFError if not enough bytes are available.
Then read it via :func:`read_pickled_object`.
"""
stream = BytesIO()
read_size = 0
while read_size < size:
data_raw = p.read(size - read_size)
if len(data_raw) == 0:
raise EOFError("expected to read %i bytes but got EOF after %i bytes" % (size, read_size))
read_size += len(data_raw)
stream.write(data_raw)
stream.seek(0)
return stream


def read_pickled_object(p: typing.BinaryIO, *, encoding=None) -> Any:
"""
Read pickled object from stream p,
after it was written via :func:`read_bytes_to_new_buffer`.
:param p:
:param encoding: if given, passed to Unpickler
"""
from returnn.util.task_system import Unpickler
import struct

size_raw = read_bytes_to_new_buffer(p, 4)
(size,) = struct.unpack("<i", size_raw)
assert size > 0, "%s: We expect to get some non-empty package. Invalid Python mod in Sprint?" % (self,)
stream = read_bytes_to_new_buffer(p, size)
unpickler_kwargs = {}
if encoding:
unpickler_kwargs["encoding"] = encoding
return Unpickler(stream, **unpickler_kwargs).load()


def write_pickled_object(p: typing.BinaryIO, obj: Any):
"""
Writes pickled object to stream p.
"""
stream = io.BytesIO()
Pickler(stream).dump(data)
raw_data = stream.getvalue()
assert len(raw_data) > 0
p.write(struct.pack("<i", len(raw_data)))
p.write(raw_data)
p.flush()


def load_txt_vector(filename):
"""
Expect line-based text encoding in file.
Expand Down

0 comments on commit 09adef0

Please sign in to comment.