Skip to content

Commit

Permalink
TF engine, save meta info in LR file, like PT
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Nov 10, 2023
1 parent f4c457e commit 2246c41
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions returnn/tf/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1754,6 +1754,13 @@ def train_epoch(self):
train_batches = self.dataset_batches["train"]

self.updater.set_learning_rate(self.learning_rate, session=self.tf_session)
self.learning_rate_control.epoch_data[self.epoch].meta.update(
{
"global_train_step": self.global_train_step,
"effective_learning_rate": self.tf_session.run(self.updater.learning_rate),
}
)

trainer = Runner(
engine=self,
dataset_name="train",
Expand Down Expand Up @@ -1810,6 +1817,12 @@ def train_epoch(self):
self.learning_rate_control.set_epoch_error(
self.epoch, {"train_score": trainer.score, "train_error": trainer.error}
)
self.learning_rate_control.epoch_data[self.epoch].meta.update(
{
"epoch_num_train_steps": trainer.num_steps,
"epoch_train_time_secs": round(trainer.elapsed),
}
)
if self._do_save():
self.learning_rate_control.save()

Expand Down

0 comments on commit 2246c41

Please sign in to comment.