Skip to content

Commit

Permalink
Support pickling dim tags (#937)
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz authored Feb 7, 2022
1 parent 4b500fd commit d030cde
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 4 deletions.
25 changes: 25 additions & 0 deletions returnn/tf/util/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,22 @@ def short_repr(self):
desc += "{ctx=%s}" % self.control_flow_ctx.repr_inner()
return desc

def __getstate__(self):
d = vars(self).copy()
d["batch"] = None
d["_same_as_tb"] = None
d["_same_for_batch_ctx"] = {}
d["dyn_size_ext"] = None
d["kind"] = self.kind.name if self.kind else None
return d

def __setstate__(self, state):
self.__dict__.update(state)
if self.kind is not None:
self.kind = {v.name: v for v in Dim.Types.Types}[self.kind]
if self.is_batch_dim():
self.same_as = batch_dim

def __copy__(self):
"""
Normally we would not want to get a new tag with ``tag != copy(tag)``.
Expand Down Expand Up @@ -2042,6 +2058,9 @@ def short_repr(self):
# "x" is the Theano-style shortcut for a broadcast dim.
return "&".join([dim.short_repr() for dim in self.virtual_dims] or ["Bx"])

def __getstate__(self):
raise Exception("Pickling of BatchInfo is not supported. (%s)" % self)

@property
def dim(self):
"""
Expand Down Expand Up @@ -3016,6 +3035,12 @@ def __repr__(self):
def __hash__(self):
return id(self)

def __getstate__(self):
d = vars(self)
d["_batch"] = None # BatchInfo pickling not supported
d["_placeholder"] = None # do not store the TF tensors
return d

def _adapt_batch_consistent_dim_tags(self):
if not self.batch: # uninitialized
return
Expand Down
11 changes: 8 additions & 3 deletions returnn/util/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ def __str__(self):
def __repr__(self):
return "<%s>" % self.name

def __getstate__(self):
raise Exception("Cannot pickle Entity object. (%r)" % self)


class OptionalNotImplementedError(NotImplementedError):
"""
Expand Down Expand Up @@ -2520,16 +2523,18 @@ def pip_install(*pkg_names):

def pip_check_is_installed(pkg_name):
"""
:param str pkg_name: without version, e.g. just "tensorflow" or "tensorflow-gpu"
:param str pkg_name: without version, e.g. just "tensorflow", or with version, e.g. "tensorflow==1.2.3"
:rtype: bool
"""
if not _pip_installed_packages:
py = sys.executable
pip_path = which_pip()
cmd = [py, pip_path, "freeze"]
for line in sys_exec_out(*cmd).splitlines():
if line:
_pip_installed_packages.add(line[:line.index("==")])
if line and "==" in line:
if "==" not in pkg_name:
line = line[:line.index("==")]
_pip_installed_packages.add(line)
return pkg_name in _pip_installed_packages


Expand Down
2 changes: 1 addition & 1 deletion tests/pycharm-inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def setup_pycharm_python_interpreter(pycharm_dir):
if not pip_check_is_installed("Theano"):
pip_install("theano==0.9")
# Note: Horovod will usually fail to install in this env.
for pkg in ["typing", "librosa", "PySoundFile", "nltk", "matplotlib", "mpi4py", "pycodestyle"]:
for pkg in ["typing", "librosa==0.8.1", "PySoundFile", "nltk", "matplotlib", "mpi4py", "pycodestyle"]:
if not pip_check_is_installed(pkg):
try:
pip_install(pkg)
Expand Down
69 changes: 69 additions & 0 deletions tests/test_TFNetworkLayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6480,6 +6480,75 @@ def test_batch_norm_args():
assert batch_norm_args == layer_args # different arguments in BatchNormLayer and LayerBase.batch_norm()


def test_pickle_dim_tags():
# Test for pickling net dict and extern data.
# https://github.com/rwth-i6/returnn_common/issues/104
# What kind of network we pickle here is not really relevant.
# I took this from some other test.
# It should just involve some dim tags.
from returnn.tf.util.data import batch_dim, ImplicitDynSizeDim
n_batch, n_time, n_ts, n_in, n_out = 2, 3, 6, 7, 11
time_dim = SpatialDim("time")
feat_dim = FeatureDim("in-feature", dimension=n_in)
ts_dim = SpatialDim("ts", dimension=n_ts)
out_dim = FeatureDim("out-feature", dimension=n_out)
config = Config({
"extern_data": {"data": {"dim_tags": [batch_dim, time_dim, feat_dim]}},
"network": {
"t": {
"class": "range_in_axis", "axis": "t", "from": "data",
"out_shape": {time_dim, ImplicitDynSizeDim(batch_dim)}}, # (T,)
"range": {"class": "range", "limit": n_ts, "out_spatial_dim": ts_dim}, # (Ts,)
"add_t": {
"class": "combine", "kind": "add", "from": ["t", "range"],
"out_shape": {time_dim, ts_dim, ImplicitDynSizeDim(batch_dim)}}, # (T,Ts)
"t_rel_var": {"class": "variable", "shape": (ts_dim, out_dim), "init": "glorot_uniform"}, # (Ts,D)
"output": {
"class": "scatter_nd", "from": "t_rel_var", "position": "add_t", "position_axis": ts_dim,
"out_spatial_dim": time_dim, "filter_invalid_indices": True} # (T,T,D)
}
})

# First test as-is.
with make_scope() as session:
network = TFNetwork(config=config, train_flag=True)
network.construct_from_dict(config.typed_dict["network"])
network.initialize_params(session)
fetches = network.get_fetches_dict()
out_layer = network.get_default_output_layer()
print("out layer:", out_layer)
session.run((out_layer.output.placeholder, fetches), feed_dict=make_feed_dict(network.extern_data))

def _debug_dump(tuple_path, obj):
print("%s: %s" % (tuple_path, type(obj)))
assert isinstance(obj, (Dim, str, set, str, bool, int))

from tensorflow.python.util import nest
nest.map_structure_with_tuple_paths(_debug_dump, config.typed_dict)

# Now pickle, unpickle and test again.
# Use pickle._dumps for easier debugging.
# noinspection PyUnresolvedReferences
from pickle import _dumps as dumps
s = dumps(config.typed_dict)
import pickle
config_dict = pickle.loads(s)
new_dim_tags = config_dict["extern_data"]["data"]["dim_tags"]
new_batch, new_time, new_feat = new_dim_tags
assert isinstance(new_batch, Dim) and new_batch == batch_dim and new_batch.is_batch_dim()
assert isinstance(new_time, Dim) and new_time.is_spatial_dim() and new_time.dimension is None
assert isinstance(new_feat, Dim) and new_feat.is_feature_dim() and new_feat.dimension == n_in
config = Config(config_dict)
with make_scope() as session:
network = TFNetwork(config=config, train_flag=True)
network.construct_from_dict(config.typed_dict["network"])
network.initialize_params(session)
fetches = network.get_fetches_dict()
out_layer = network.get_default_output_layer()
print("out layer:", out_layer)
session.run((out_layer.output.placeholder, fetches), feed_dict=make_feed_dict(network.extern_data))


def test_contrastive_loss():
from returnn.tf.util.data import batch_dim
masked_time_dim = SpatialDim("masked_time")
Expand Down

0 comments on commit d030cde

Please sign in to comment.