Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support pickling dim tags #937

Merged
merged 9 commits into from
Feb 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions returnn/tf/util/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,22 @@ def short_repr(self):
desc += "{ctx=%s}" % self.control_flow_ctx.repr_inner()
return desc

def __getstate__(self):
d = vars(self).copy()
d["batch"] = None
d["_same_as_tb"] = None
d["_same_for_batch_ctx"] = {}
d["dyn_size_ext"] = None
d["kind"] = self.kind.name if self.kind else None
return d

def __setstate__(self, state):
self.__dict__.update(state)
if self.kind is not None:
self.kind = {v.name: v for v in Dim.Types.Types}[self.kind]
if self.is_batch_dim():
self.same_as = batch_dim

def __copy__(self):
"""
Normally we would not want to get a new tag with ``tag != copy(tag)``.
Expand Down Expand Up @@ -2042,6 +2058,9 @@ def short_repr(self):
# "x" is the Theano-style shortcut for a broadcast dim.
return "&".join([dim.short_repr() for dim in self.virtual_dims] or ["Bx"])

def __getstate__(self):
raise Exception("Pickling of BatchInfo is not supported. (%s)" % self)

@property
def dim(self):
"""
Expand Down Expand Up @@ -3016,6 +3035,12 @@ def __repr__(self):
def __hash__(self):
return id(self)

def __getstate__(self):
d = vars(self)
d["_batch"] = None # BatchInfo pickling not supported
d["_placeholder"] = None # do not store the TF tensors
return d

def _adapt_batch_consistent_dim_tags(self):
if not self.batch: # uninitialized
return
Expand Down
11 changes: 8 additions & 3 deletions returnn/util/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ def __str__(self):
def __repr__(self):
return "<%s>" % self.name

def __getstate__(self):
raise Exception("Cannot pickle Entity object. (%r)" % self)


class OptionalNotImplementedError(NotImplementedError):
"""
Expand Down Expand Up @@ -2520,16 +2523,18 @@ def pip_install(*pkg_names):

def pip_check_is_installed(pkg_name):
"""
:param str pkg_name: without version, e.g. just "tensorflow" or "tensorflow-gpu"
:param str pkg_name: without version, e.g. just "tensorflow", or with version, e.g. "tensorflow==1.2.3"
:rtype: bool
"""
if not _pip_installed_packages:
py = sys.executable
pip_path = which_pip()
cmd = [py, pip_path, "freeze"]
for line in sys_exec_out(*cmd).splitlines():
if line:
_pip_installed_packages.add(line[:line.index("==")])
if line and "==" in line:
if "==" not in pkg_name:
line = line[:line.index("==")]
_pip_installed_packages.add(line)
return pkg_name in _pip_installed_packages


Expand Down
2 changes: 1 addition & 1 deletion tests/pycharm-inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def setup_pycharm_python_interpreter(pycharm_dir):
if not pip_check_is_installed("Theano"):
pip_install("theano==0.9")
# Note: Horovod will usually fail to install in this env.
for pkg in ["typing", "librosa", "PySoundFile", "nltk", "matplotlib", "mpi4py", "pycodestyle"]:
for pkg in ["typing", "librosa==0.8.1", "PySoundFile", "nltk", "matplotlib", "mpi4py", "pycodestyle"]:
if not pip_check_is_installed(pkg):
try:
pip_install(pkg)
Expand Down
69 changes: 69 additions & 0 deletions tests/test_TFNetworkLayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6480,6 +6480,75 @@ def test_batch_norm_args():
assert batch_norm_args == layer_args # different arguments in BatchNormLayer and LayerBase.batch_norm()


def test_pickle_dim_tags():
# Test for pickling net dict and extern data.
# https://github.com/rwth-i6/returnn_common/issues/104
# What kind of network we pickle here is not really relevant.
# I took this from some other test.
# It should just involve some dim tags.
from returnn.tf.util.data import batch_dim, ImplicitDynSizeDim
n_batch, n_time, n_ts, n_in, n_out = 2, 3, 6, 7, 11
time_dim = SpatialDim("time")
feat_dim = FeatureDim("in-feature", dimension=n_in)
ts_dim = SpatialDim("ts", dimension=n_ts)
out_dim = FeatureDim("out-feature", dimension=n_out)
config = Config({
"extern_data": {"data": {"dim_tags": [batch_dim, time_dim, feat_dim]}},
"network": {
"t": {
"class": "range_in_axis", "axis": "t", "from": "data",
"out_shape": {time_dim, ImplicitDynSizeDim(batch_dim)}}, # (T,)
"range": {"class": "range", "limit": n_ts, "out_spatial_dim": ts_dim}, # (Ts,)
"add_t": {
"class": "combine", "kind": "add", "from": ["t", "range"],
"out_shape": {time_dim, ts_dim, ImplicitDynSizeDim(batch_dim)}}, # (T,Ts)
"t_rel_var": {"class": "variable", "shape": (ts_dim, out_dim), "init": "glorot_uniform"}, # (Ts,D)
"output": {
"class": "scatter_nd", "from": "t_rel_var", "position": "add_t", "position_axis": ts_dim,
"out_spatial_dim": time_dim, "filter_invalid_indices": True} # (T,T,D)
}
})

# First test as-is.
with make_scope() as session:
network = TFNetwork(config=config, train_flag=True)
network.construct_from_dict(config.typed_dict["network"])
network.initialize_params(session)
fetches = network.get_fetches_dict()
out_layer = network.get_default_output_layer()
print("out layer:", out_layer)
session.run((out_layer.output.placeholder, fetches), feed_dict=make_feed_dict(network.extern_data))

def _debug_dump(tuple_path, obj):
print("%s: %s" % (tuple_path, type(obj)))
assert isinstance(obj, (Dim, str, set, str, bool, int))

from tensorflow.python.util import nest
nest.map_structure_with_tuple_paths(_debug_dump, config.typed_dict)

# Now pickle, unpickle and test again.
# Use pickle._dumps for easier debugging.
# noinspection PyUnresolvedReferences
from pickle import _dumps as dumps
s = dumps(config.typed_dict)
import pickle
config_dict = pickle.loads(s)
new_dim_tags = config_dict["extern_data"]["data"]["dim_tags"]
new_batch, new_time, new_feat = new_dim_tags
assert isinstance(new_batch, Dim) and new_batch == batch_dim and new_batch.is_batch_dim()
assert isinstance(new_time, Dim) and new_time.is_spatial_dim() and new_time.dimension is None
assert isinstance(new_feat, Dim) and new_feat.is_feature_dim() and new_feat.dimension == n_in
config = Config(config_dict)
with make_scope() as session:
network = TFNetwork(config=config, train_flag=True)
network.construct_from_dict(config.typed_dict["network"])
network.initialize_params(session)
fetches = network.get_fetches_dict()
out_layer = network.get_default_output_layer()
print("out layer:", out_layer)
session.run((out_layer.output.placeholder, fetches), feed_dict=make_feed_dict(network.extern_data))


def test_contrastive_loss():
from returnn.tf.util.data import batch_dim
masked_time_dim = SpatialDim("masked_time")
Expand Down