-
Notifications
You must be signed in to change notification settings - Fork 1
/
logger.py
44 lines (33 loc) · 1.53 KB
/
logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import tensorflow as tf
import pathlib
from datetime import datetime
import os
class Logger(object):
    """Thin wrapper around a TF1-style ``tf.summary.FileWriter`` for scalar logging.

    The writer buffers events, so call ``close()`` (or use the logger as a
    context manager) when done to make sure everything reaches disk.
    """

    def __init__(self, log_dir):
        """Open an event-file writer rooted at *log_dir* (a str or path-like)."""
        # os.fspath accepts both str and pathlib.Path and always hands
        # TensorFlow a plain string path.
        self.writer = tf.summary.FileWriter(os.fspath(log_dir))

    def log_scalar(self, tag, value, step):
        """Record scalar *value* under *tag* at global step *step*."""
        # float() lets numpy scalars / 0-d arrays / tensor scalars through the
        # proto's float field, which would otherwise reject some of them.
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag,
                                                     simple_value=float(value))])
        self.writer.add_summary(summary, step)

    def flush(self):
        """Write any buffered events to disk."""
        self.writer.flush()

    def close(self):
        """Flush pending events and close the underlying writer."""
        self.writer.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Always release the writer; never suppress an in-flight exception.
        self.close()
def _timestamped_run_dir():
    """Return ``<cwd>/train_logs/RL-agent-<timestamp>`` for a new training run.

    The timestamp uses ``%d%b%y_%I%M%p`` (minute resolution, 12-hour clock),
    so runs started within the same minute map to the same directory.
    """
    stamp = datetime.now().strftime("%d%b%y_%I%M%p")
    return pathlib.Path(os.getcwd()) / 'train_logs' / f'RL-agent-{stamp}'


def setup_logger(args):
    """Create the TensorBoard loggers requested by the command-line arguments.

    Args:
        args: object exposing ``logger`` (truthy to enable logging) and
            ``alg`` (``'nns'``, ``'heft'`` or ``'compare'``).

    Returns:
        ``(logger_nns, logger_heft)``. Each entry is a ``Logger``, or ``None``
        when that algorithm is not being logged. Logging disabled or an
        unrecognised ``alg`` yields ``(None, None)``.
    """
    # Default both to None so every path returns a pair instead of hitting
    # UnboundLocalError on an unassigned name.
    logger_nns = None
    logger_heft = None
    if not args.logger:
        return logger_nns, logger_heft

    run_dir = _timestamped_run_dir()
    if args.alg == 'nns':
        logger_nns = Logger(run_dir)
    elif args.alg == 'heft':
        logger_heft = Logger(run_dir)
    elif args.alg == 'compare':
        # Comparison runs put each algorithm in its own sub-directory of one run.
        logger_nns = Logger(log_dir=run_dir / 'nns')
        logger_heft = Logger(log_dir=run_dir / 'heft')
    return logger_nns, logger_heft
def setup_logger_all():
    """Build NNS, HEFT and DQTS loggers under one timestamped run directory.

    Each algorithm gets its own sub-directory (``nns``, ``heft``, ``dqts``)
    of ``<cwd>/train_logs/RL-agent-<timestamp>``.

    Returns:
        ``(logger_nns, logger_dqts, logger_heft)`` -- note DQTS comes second.
    """
    stamp = datetime.now().strftime("%d%b%y_%I%M%p")
    run_dir = pathlib.Path(os.getcwd()) / 'train_logs' / f'RL-agent-{stamp}'

    # Construct in the order nns, heft, dqts; the return order differs.
    by_name = {}
    for sub in ('nns', 'heft', 'dqts'):
        by_name[sub] = Logger(log_dir=run_dir / sub)

    return by_name['nns'], by_name['dqts'], by_name['heft']