-
Notifications
You must be signed in to change notification settings - Fork 66
/
graph_training.py
69 lines (59 loc) · 1.8 KB
/
graph_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import numpy as np
import matplotlib.pyplot as plt
from src.graph_globals import global_params
from src.graphs import graph, multi_line, get_cmap
from src.picklefuncs import load_data
def get_headers(fp):
    """Return the column names found on the first line of a CSV file.

    Args:
        fp: Path of the comma-separated file to read.

    Returns:
        A list of strings, one per comma-separated field of the first line.
        An empty first line yields ``['']``.
    """
    with open(fp, 'r') as f:
        header = f.readline()
    # Strip only the line terminator. Unconditionally dropping the last
    # character would eat a real character of the last column name whenever
    # the header line is not newline-terminated (e.g. a header-only file).
    headers = header.rstrip('\r\n').split(',')
    print(headers)
    return headers
def get_data(fp):
    """Load the numeric columns of a CSV file as a list of series.

    The first line of the file is skipped (it is read separately as the
    header row), and the remaining comma-separated values are parsed with
    ``np.loadtxt``.

    Args:
        fp: Path of the comma-separated file to read.

    Returns:
        A list with one 1-D numpy array per column of the file, each holding
        that column's values in row order.
    """
    # ndmin=2 keeps the (rows, columns) layout even when the file has a single
    # row and/or a single column. Without it loadtxt squeezes the result, so a
    # one-row file is indistinguishable from a one-column file and a
    # one-value file becomes a 0-d array that cannot be iterated.
    data = np.loadtxt(fp, delimiter=',', skiprows=1, ndmin=2).T
    return list(data)
def graph_data(data, labels, metric):
    """Draw all series in ``data`` on a single axes and display the figure.

    Args:
        data: The series to plot; passed unchanged to the project helpers
            ``multi_line`` and ``graph``.
        labels: One label per series; its length also decides how many
            colours are requested from ``get_cmap``.
        metric: Text used as the first element of the ``ytitle_pad`` tuple
            (presumably the y-axis title -- confirm against ``graph``).
    """
    _fig, ax = plt.subplots(1, 1)

    # One colour per label, taken from the colormap in label order.
    n_series = len(labels)
    palette = get_cmap(n_series)
    colours = list(map(palette, range(n_series)))

    plotted = multi_line(ax, data, colours, labels)
    graph(
        ax,
        data,
        plotted,
        xtitle='Time',
        ytitle_pad=(metric, 60),
        title='Training Updates Progress',
        legend=(0.92, 0.92),
        grid=True,
    )

    # Display the figure via matplotlib.
    plt.show()
def graph_metric(path, metric):
    """Plot the most recent file in ``path`` whose name contains ``metric``.

    "Most recent" means the last matching filename in sorted (lexicographic)
    order; this assumes filenames sort chronologically -- TODO confirm with
    whatever writes these files.

    Args:
        path: Directory to search. A trailing separator is optional.
        metric: Substring that selects the files for this metric; also
            forwarded to ``graph_data``.

    Raises:
        FileNotFoundError: If no filename in ``path`` contains ``metric``.
    """
    matches = sorted(name for name in os.listdir(path) if metric in name)
    if not matches:
        # Fail with a message naming the metric and directory instead of an
        # opaque "list index out of range".
        raise FileNotFoundError(
            "no file containing {!r} found in {!r}".format(metric, path)
        )
    newest_fp = matches[-1]
    print(metric)
    print(newest_fp)
    # os.path.join works whether or not path ends with a separator; plain
    # string concatenation silently built a wrong path without one.
    fp = os.path.join(path, newest_fp)
    labels = get_headers(fp)
    data = get_data(fp)
    print(data)
    graph_data(data, labels, metric)
def main():
    """Plot the newest file for each training metric found under ``tmp/``.

    Calls ``global_params()`` first, then runs ``graph_metric`` once for each
    of the 'replay', 'updates' and 'nexp' metrics, in that order.
    """
    global_params()
    path = 'tmp/'
    metrics = ['replay', 'updates', 'nexp']
    for m in metrics:
        graph_metric(path, m)
# Run the plotting pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()