This repository has been archived by the owner on Jan 12, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
callbacks.py
46 lines (31 loc) · 1.29 KB
/
callbacks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# Copyright (c) 2021 Kemal Kurniawan
from typing import Callable
from rnnr.callbacks import save
from sacred.run import Run
import torch
def update_params(opt: torch.optim.Optimizer) -> Callable[[dict], None]:
    """Return a callback performing one optimization step per invocation.

    The returned callback reads ``state["loss"]``, clears any stale
    gradients on *opt*, backpropagates the loss, and applies the update.
    """

    def step(state: dict) -> None:
        loss = state["loss"]
        # Zero first so gradients from the previous batch don't accumulate.
        opt.zero_grad()
        loss.backward()
        opt.step()

    return step
def log_grads(run: Run, model: torch.nn.Module, every: int = 10) -> Callable[[dict], None]:
    """Return a callback that logs gradient norms to a Sacred run.

    Every *every* iterations (as counted by ``state["n_iters"]``), the
    L2 norm of each trainable parameter's gradient is recorded under the
    metric name ``grad_<param_name>``.
    """

    def callback(state: dict) -> None:
        step = state["n_iters"]
        if step % every:  # only log on multiples of `every`
            return
        trainable = ((n, p) for n, p in model.named_parameters() if p.requires_grad)
        for name, param in trainable:
            run.log_scalar(f"grad_{name}", param.grad.norm().item(), step)

    return callback
def log_stats(run: Run, every: int = 10) -> Callable[[dict], None]:
    """Return a callback that logs batch statistics to a Sacred run.

    Every *every* iterations, each entry of ``state["stats"]`` and of the
    optional ``state["extra_stats"]`` is recorded under the metric name
    ``batch_<stat_name>``.
    """

    def callback(state: dict) -> None:
        step = state["n_iters"]
        if step % every:  # only log on multiples of `every`
            return
        # Log regular stats first, then any extra stats (both optional-keyed
        # lookups preserve the original ordering and duplicate-name behavior).
        for source in (state["stats"], state.get("extra_stats", {})):
            for name, value in source.items():
                run.log_scalar(f"batch_{name}", value, step)

    return callback
def save_state_dict(*args, **kwargs) -> Callable[[dict], None]:
    """Return rnnr's ``save`` callback configured to store state dicts.

    Forces the saver to serialize ``model.state_dict()`` via ``torch.save``
    with a ``.pth`` extension, overriding any caller-supplied ``using``/``ext``.
    All other arguments are forwarded to ``rnnr.callbacks.save`` unchanged.
    """
    kwargs["using"] = lambda model, path: torch.save(model.state_dict(), path)
    kwargs["ext"] = "pth"
    return save(*args, **kwargs)