dynamic_learning_rate, extra checks
albertz committed Nov 3, 2023
1 parent 5c82ab3 commit 4e76da6
Showing 2 changed files with 4 additions and 0 deletions.
returnn/tf/updater.py (2 additions, 0 deletions)
@@ -222,6 +222,8 @@ def get_current_step_learning_rate(self):
         assert any(
             [arg.kind == inspect.Parameter.VAR_KEYWORD for arg in signature.parameters.values()]
         ), "please specify **kwargs in dynamic_learning_rate for future compatibility"
+        if "epoch" in signature.parameters:
+            raise NotImplementedError("TF updater: dynamic_learning_rate with epoch not supported currently")
         lr = learning_rate_function(
             network=self.network, global_train_step=self.global_train_step, learning_rate=lr
         )
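For context, a minimal sketch of a config-defined dynamic_learning_rate that passes the checks in this hunk for the TF backend: it declares **kwargs (required by the assert), avoids an epoch parameter (which would now raise NotImplementedError), and uses the keyword arguments visible at the call site above. The warmup schedule itself is purely illustrative, not part of this commit.

import tensorflow as tf

def dynamic_learning_rate(*, network, global_train_step, learning_rate, **kwargs):
    """Illustrative linear warmup over the first 10k steps."""
    warmup_steps = 10000.0
    step = tf.cast(global_train_step, tf.float32)
    return learning_rate * tf.minimum(step / warmup_steps, 1.0)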
returnn/torch/updater.py (2 additions, 0 deletions)
@@ -105,6 +105,8 @@ def __init__(self, *, config, network, device, initial_learning_rate=1.0):
             assert any(
                 [arg.kind == inspect.Parameter.VAR_KEYWORD for arg in signature.parameters.values()]
             ), "please specify **kwargs in dynamic_learning_rate for future compatibility"
+            if "network" in signature.parameters:
+                raise ValueError("Torch updater: dynamic_learning_rate network is TF specific")
         else:
             raise NotImplementedError("not implemented for not callable dynamic_learning_rate")

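Similarly, a sketch of what a dynamic_learning_rate for the PyTorch backend might look like under these checks. The diff only shows the constraints (the function must accept **kwargs and must not declare the TF-specific network parameter); the exact keyword arguments the Torch updater passes are not visible here, so global_train_step and learning_rate below are assumptions for illustration.

def dynamic_learning_rate(*, global_train_step, learning_rate, **kwargs):
    """Illustrative inverse-sqrt decay after a fixed warmup."""
    warmup_steps = 10000
    if global_train_step < warmup_steps:
        return learning_rate * (global_train_step + 1) / warmup_steps
    return learning_rate * (warmup_steps / (global_train_step + 1)) ** 0.5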
