Skip to content

Commit

Permalink
small fix
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Jul 9, 2024
1 parent 6151c54 commit db2d459
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions users/zeyer/model_interfaces/model_with_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List, Dict, Set, Union
import dataclasses
from i6_core.returnn.training import Checkpoint as _TfCheckpoint, PtCheckpoint as _PtCheckpoint
from i6_core.returnn.training import ReturnnTrainingJob, Checkpoint as _TfCheckpoint, PtCheckpoint as _PtCheckpoint
from i6_experiments.users.zeyer.returnn.training import default_returnn_keep_epochs

if TYPE_CHECKING:
from sisyphus import tk as _tk
from i6_core.returnn.training import ReturnnTrainingJob
from .model import ModelDef, ModelDefWithCfg
from .recog import RecogDef

Expand Down Expand Up @@ -110,6 +109,12 @@ def from_training_job(
num_pretrain_epochs=num_pretrain_epochs,
)

def get_training_job(self) -> ReturnnTrainingJob:
"""get training job (assuming the scores_and_learning_rates comes from it)"""
job = self.scores_and_learning_rates.creator
assert isinstance(job, ReturnnTrainingJob), f"scores_and_learning_rates {self.scores_and_learning_rates}"
return job

@property
def last_fixed_epoch_idx(self) -> int:
"""last epoch"""
Expand Down

0 comments on commit db2d459

Please sign in to comment.