# victini_trainer.py
import json
import logging
import random
from pathlib import Path
import torch
from joblib import Parallel, delayed
from tqdm.auto import tqdm
from alphagogoat.constants import DEVICE
from alphagogoat.victini import Victini, train, evaluate
from alphagogoat.utils import make_victini_data
def main():
    """Train Victini on cached replays, checkpoint every pass, then evaluate.

    Collects every file in ``cache/replays`` whose name ends with ``.json``,
    holds out the last 100 (in sorted path order) as the test set, and trains
    on the first 100 of the remaining files for ``reps`` passes. Weights are
    restored from ``victini.pth`` when that file exists and are written back
    to it after each training pass. The held-out files are evaluated once
    after training finishes.
    """
    victini_path = "victini.pth"
    n_jobs = 4  # worker count for parallel replay preprocessing
    reps = 100

    # Path.iterdir() yields entries in arbitrary order. Sorting keeps the
    # train/test split identical across runs, so a model restored from the
    # checkpoint is never evaluated on files an earlier run trained on.
    json_files = sorted(
        filepath
        for filepath in Path("cache/replays").iterdir()
        if filepath.name.endswith(".json")
    )
    train_files, test_files = json_files[:-100], json_files[-100:]

    victini = Victini(964)
    if Path(victini_path).exists():
        # Map storages to CPU while unpickling so a checkpoint saved on a
        # different device can still be read; load_state_dict then copies the
        # values into the model's existing parameters.
        victini.load_state_dict(torch.load(victini_path, map_location="cpu"))

    for _ in range(reps):
        # "MEDIUM" subset: the first 100 training files. Use train_files for
        # the full set or train_files[:30] for a quick "SMALL" run.
        # NOTE(review): the same files are re-processed on every pass; if
        # make_victini_data is deterministic this could be built once before
        # the loop — confirm before hoisting.
        train_data = Parallel(n_jobs=n_jobs)(
            delayed(make_victini_data)(filepath)
            for filepath in tqdm(train_files[:100])
        )
        train(victini, train_data, lr=0.0001, weight_decay=0, discount=0)
        # Save after each pass so an interrupted run keeps its progress.
        torch.save(victini.state_dict(), victini_path)

    test_data = Parallel(n_jobs=n_jobs)(
        delayed(make_victini_data)(filepath) for filepath in tqdm(test_files)
    )
    evaluate(victini, test_data)
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()