-
Notifications
You must be signed in to change notification settings - Fork 0
/
beam_search_diagnostics.py
58 lines (50 loc) · 1.36 KB
/
beam_search_diagnostics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
from datetime import datetime
import socket
import base64
import json
from pathlib import Path
import os
def get_diagnostic_dir():
    """Return the beam-search diagnostics directory, creating it if needed.

    The directory is ``outputs/beam_search_diagnostics`` below the current
    working directory of the process. Missing parent directories are created,
    and an already existing directory is not an error.

    Returns:
        pathlib.Path: absolute path of the diagnostics directory.
    """
    target_dir = Path(os.getcwd()).joinpath("outputs", "beam_search_diagnostics")
    target_dir.mkdir(parents=True, exist_ok=True)
    return target_dir
def get_diagnostic_info():
    """Return a base64-encoded JSON snapshot of when and where the code ran.

    The encoded JSON object has three keys:

    * ``"t"`` -- current UTC time as a naive ISO-8601 string (no offset
      suffix),
    * ``"h"`` -- host name reported by ``socket.gethostname()``,
    * ``"u"`` -- login name from ``os.getlogin()``, or ``"unknown_user"``
      when that call raises ``OSError``.

    Note that base64 is an encoding, not encryption: anyone holding the
    returned string can recover the host and user names.

    Returns:
        str: the UTF-8 JSON text, base64-encoded, as an ASCII ``str``.
    """
    # Imported locally because the file-level import only brings in
    # ``datetime`` itself.
    from datetime import timezone

    try:
        username = os.getlogin()
    except OSError:
        # os.getlogin() depends on a controlling terminal; fall back to a
        # placeholder when the login name cannot be determined.
        username = "unknown_user"

    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
    # timestamp and drop the tzinfo so the serialized string keeps exactly
    # the same naive format as before.
    timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    info = {
        "t": timestamp,
        "h": socket.gethostname(),
        "u": username,
    }
    encoded = base64.b64encode(json.dumps(info).encode("utf-8"))
    return encoded.decode("utf-8")
def record_train_diagnostics(data, iter):
    """Write one diagnostics document to the diagnostics directory.

    The file is named after the iteration, zero-padded to six digits
    (for example ``000042.json``), and replaces any existing file of the
    same name.

    Args:
        data: text to write (``str``).
        iter: iteration number used for the file name; it must support the
            ``06`` format specification.
    """
    out_path = get_diagnostic_dir() / f"{iter:06}.json"
    # Pin the encoding: the text may contain non-ASCII characters, and the
    # platform default encoding is not guaranteed to be able to represent
    # them.
    out_path.write_text(data, encoding="utf-8")
def format_example_sentence(source, target, hypothesis_beam, iter):
    """Render one beam-search example as Markdown and persist it as JSON.

    A JSON document is built holding the source, the target, every
    hypothesis in the beam (its ``value`` and ``score`` attributes) and the
    string returned by ``get_diagnostic_info()``. The document is handed to
    ``record_train_diagnostics`` for this iteration and is also embedded in
    the returned Markdown snippet.

    Args:
        source: example source, stored under ``"example_source"``.
        target: example target, stored under ``"example_target"``.
        hypothesis_beam: iterable of objects exposing ``value`` and
            ``score`` attributes.
        iter: iteration number, used in the heading and passed on to
            ``record_train_diagnostics``.

    Returns:
        str: a Markdown heading followed by the JSON in a fenced code block.
    """
    hypotheses = []
    for candidate in hypothesis_beam:
        hypotheses.append({"hypothesis": candidate.value, "score": candidate.score})

    payload = {
        "example_source": source,
        "example_target": target,
        "hypotheses": hypotheses,
        "diagnostic_info": get_diagnostic_info(),
    }
    # ensure_ascii=False keeps non-ASCII characters as-is instead of
    # emitting \uXXXX escapes.
    formatted_json = json.dumps(payload, ensure_ascii=False, indent=4)

    record_train_diagnostics(formatted_json, iter)

    return (
        f"## Example of translation with beam search @ Iteration {iter}:\n"
        "```\n"
        f"{formatted_json}\n"
        "```\n"
    )