Torch optimizer, extended format
albertz committed Nov 4, 2023
1 parent 466ebc7 commit b4dbadd
Showing 2 changed files with 49 additions and 2 deletions.
35 changes: 33 additions & 2 deletions returnn/torch/updater.py
@@ -11,6 +11,7 @@
import typing
from typing import Any, Set, Dict, Optional

import returnn
from returnn.log import log
from returnn.util.basic import RefIdEq

@@ -111,6 +112,7 @@ def __init__(self, *, config, network, device, initial_learning_rate=1.0):
else:
raise NotImplementedError("not implemented for not callable dynamic_learning_rate")

self._optimizer_opts = None
self.optimizer = None # type: typing.Optional[torch.optim.Optimizer]

self._grad_clip_global_norm = self.config.float("gradient_clip_global_norm", 0.0)
@@ -192,6 +194,7 @@ def create_optimizer(self):
optimizer_opts = self.config.typed_value("optimizer", None)
if optimizer_opts is None:
raise ValueError("config field 'optimizer' needs to be set explicitly for the Torch backend")
self._optimizer_opts = optimizer_opts
self.optimizer = self._create_optimizer(optimizer_opts)

def load_optimizer(self, filename):
@@ -202,7 +205,11 @@ def load_optimizer(self, filename):
"""
print("Load optimizer %s" % filename, file=log.v4)
optimizer_state = torch.load(filename, map_location=self._device)
self.optimizer.load_state_dict(optimizer_state)
assert isinstance(optimizer_state, dict), f"optimizer_state is not a dict but {type(optimizer_state)}"
if "optimizer" not in optimizer_state and "param_groups" in optimizer_state and "state" in optimizer_state:
# Old format, convert to new format.
optimizer_state = {"optimizer": optimizer_state}
self.optimizer.load_state_dict(optimizer_state["optimizer"])
# https://github.com/rwth-i6/returnn/issues/1345
del optimizer_state
gc.collect()
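
For context, the "old format" handled above is simply the raw torch.optim state_dict, whose top-level keys are "state" and "param_groups"; a minimal standalone sketch (the model and optimizer here are hypothetical placeholders) of what the check matches:

import torch

# Hypothetical stand-in model/optimizer; a raw state_dict like this is the old on-disk format.
model = torch.nn.Linear(7, 5)
opt = torch.optim.AdamW(model.parameters(), weight_decay=1e-3)
state = opt.state_dict()
print(sorted(state.keys()))  # ['param_groups', 'state'] -> no "optimizer" key, so it gets wrapped on load
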
@@ -217,13 +224,37 @@ def save_optimizer(self, filename):
if directory and not os.path.exists(directory):
os.makedirs(directory, exist_ok=True)

# We use optimizer.state_dict() below.
# That will only save param order indices
# but not the names of the parameters.
# We also save a mapping of parameter indices to names.
param_id_to_name = {} # id -> name
for name, p in self.network.named_parameters():
param_id_to_name[id(p)] = name
param_names = [] # param_idx -> name
for group in self.optimizer.param_groups:
for p in group["params"]:
param_names.append(param_id_to_name[id(p)])

print("Save optimizer under %s" % filename, file=log.v4)
# First write to a temp-file, to be sure that writing happens without errors,
# and only afterward rename to the target file.
tmp_filename = filename + ".tmp_write"
if os.path.exists(tmp_filename):
os.unlink(tmp_filename)
torch.save(self.optimizer.state_dict(), tmp_filename)
torch.save(
{
"optimizer": self.optimizer.state_dict(),
"optimizer_class_name": self.optimizer.__class__.__name__,
"optimizer_opts": self._optimizer_opts,
"param_names": param_names,
"epoch": self._current_epoch,
"step": self._current_train_step,
"effective_learning_rate": self.get_effective_learning_rate(),
"returnn_version": returnn.__version__,
},
tmp_filename,
)
os.rename(tmp_filename, filename)

def get_optimizer(self):
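For reference, a minimal sketch of how a checkpoint written by the extended save_optimizer() above could be read back and inspected; the file name is hypothetical, and the keys are exactly those of the dict passed to torch.save() in the diff:

import torch

# Hypothetical path to an optimizer checkpoint written in the extended format.
ckpt = torch.load("model.opt.pt", map_location="cpu")
assert isinstance(ckpt, dict) and "optimizer" in ckpt
print(ckpt["optimizer_class_name"])  # e.g. "AdamW"
print(ckpt["optimizer_opts"])        # the config's "optimizer" dict
print(ckpt["param_names"])           # param_idx -> parameter name
print(ckpt["epoch"], ckpt["step"], ckpt["effective_learning_rate"])
print(ckpt["returnn_version"])
opt_state_dict = ckpt["optimizer"]   # the plain torch.optim state_dict, as before
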
16 changes: 16 additions & 0 deletions tests/test_torch_engine.py
@@ -6,13 +6,15 @@
import _setup_test_env # noqa
import sys
import unittest
import tempfile
import numpy
import torch

from returnn.util import better_exchook
from returnn.config import Config, global_config_ctx
from returnn.tensor import TensorDict, Tensor
from returnn.torch.engine import Engine
from returnn.torch.updater import Updater
import returnn.frontend as rf
from returnn.forward_iface import ForwardCallbackIface
from returnn.datasets import init_dataset
@@ -382,6 +384,20 @@ def test_data_loader_oggzip():
assert batches == [[[12, 8, 9, 11], [16, 0, 0, 0]], [[6, 25, 18, 20, 5], [28, 10, 28, 14, 0]], [[17, 23]]]


def test_load_optimizer_old_format():
config = Config(dict(optimizer={"class": "adamw", "weight_decay": 1e-3}))
model = torch.nn.Linear(7, 5)
updater = Updater(config=config, network=model, device=torch.device("cpu"))
updater.create_optimizer()

with tempfile.TemporaryDirectory(prefix="returnn_test_load_optimizer_old_format") as tmp_dir:
torch.save(updater.optimizer.state_dict(), tmp_dir + "/model.opt.old_format.pt")
updater.load_optimizer(tmp_dir + "/model.opt.old_format.pt")

updater.save_optimizer(tmp_dir + "/model.opt.new_format.pt")
updater.load_optimizer(tmp_dir + "/model.opt.new_format.pt")


if __name__ == "__main__":
better_exchook.install()
if len(sys.argv) <= 1:
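As a usage note, the same backward-compatibility check suggests a straightforward offline conversion of an existing old-format optimizer checkpoint; a minimal sketch under the assumption that only the wrapped state dict is needed (the extra metadata keys written by save_optimizer() are left out; file names are hypothetical):

import torch

old_state = torch.load("model.opt.old_format.pt", map_location="cpu")
if "optimizer" not in old_state and "param_groups" in old_state and "state" in old_state:
    # Same wrapping as load_optimizer() applies above.
    torch.save({"optimizer": old_state}, "model.opt.new_format.pt")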
