diff --git a/users/zeyer/model_interfaces/model_with_checkpoints.py b/users/zeyer/model_interfaces/model_with_checkpoints.py index 3b1da25df..36a960673 100644 --- a/users/zeyer/model_interfaces/model_with_checkpoints.py +++ b/users/zeyer/model_interfaces/model_with_checkpoints.py @@ -5,12 +5,11 @@ from __future__ import annotations from typing import TYPE_CHECKING, List, Dict, Set, Union import dataclasses -from i6_core.returnn.training import Checkpoint as _TfCheckpoint, PtCheckpoint as _PtCheckpoint +from i6_core.returnn.training import ReturnnTrainingJob, Checkpoint as _TfCheckpoint, PtCheckpoint as _PtCheckpoint from i6_experiments.users.zeyer.returnn.training import default_returnn_keep_epochs if TYPE_CHECKING: from sisyphus import tk as _tk - from i6_core.returnn.training import ReturnnTrainingJob from .model import ModelDef, ModelDefWithCfg from .recog import RecogDef @@ -110,6 +109,12 @@ def from_training_job( num_pretrain_epochs=num_pretrain_epochs, ) + def get_training_job(self) -> ReturnnTrainingJob: + """get training job (assuming the scores_and_learning_rates comes from it)""" + job = self.scores_and_learning_rates.creator + assert isinstance(job, ReturnnTrainingJob), f"scores_and_learning_rates {self.scores_and_learning_rates}" + return job + @property def last_fixed_epoch_idx(self) -> int: """last epoch"""