-
Notifications
You must be signed in to change notification settings - Fork 0
/
visualization.py
124 lines (95 loc) · 3.85 KB
/
visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import typing
import os
import logging
from abc import ABC, abstractmethod
from pathlib import Path
import torch.nn as nn
from tensorboardX import SummaryWriter
# wandb is an optional dependency: remember whether it could be imported so
# the rest of the module can pick a backend without importing it again.
try:
    import wandb
except ImportError:
    WANDB_FOUND = False
else:
    WANDB_FOUND = True

# Module-level logger shared by every visualizer defined in this file.
logger = logging.getLogger(__name__)
class TAPEVisualizer(ABC):
    """Abstract interface shared by all TAPE visualization backends.

    Every method is abstract, so a subclass must override all four of them
    before it can be instantiated. The bodies here only raise
    ``NotImplementedError``.
    """

    @abstractmethod
    def __init__(
        self,
        log_dir: typing.Union[str, Path],
        exp_name: str,
        debug: bool = False,
    ):
        """Take the log directory, the experiment name and a debug flag."""
        raise NotImplementedError()

    @abstractmethod
    def log_config(
        self,
        config: typing.Dict[str, typing.Any],
    ) -> None:
        """Record a configuration mapping."""
        raise NotImplementedError()

    @abstractmethod
    def watch(
        self,
        model: nn.Module,
    ) -> None:
        """Start tracking ``model``."""
        raise NotImplementedError()

    @abstractmethod
    def log_metrics(
        self,
        metrics_dict: typing.Dict[str, float],
        split: str,
        step: int,
    ):
        """Record the named metric values for ``split`` at ``step``."""
        raise NotImplementedError()
class DummyVisualizer(TAPEVisualizer):
    """Visualizer whose methods accept their arguments and do nothing.

    Returned by ``get`` when ``local_rank`` is neither -1 nor 0.
    """

    def __init__(
        self,
        log_dir: typing.Union[str, Path] = '',
        exp_name: str = '',
        debug: bool = False,
    ):
        """Ignore all arguments; no state is created."""
        return None

    def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
        """Discard ``config``."""
        return None

    def watch(self, model: nn.Module) -> None:
        """Ignore ``model``."""
        return None

    def log_metrics(
        self,
        metrics_dict: typing.Dict[str, float],
        split: str,
        step: int,
    ):
        """Discard the metrics."""
        return None
class TBVisualizer(TAPEVisualizer):
    """Visualizer that writes scalar metrics through a tensorboardX ``SummaryWriter``.

    Config logging and model watching are not supported by this backend;
    those methods only emit a warning.
    """

    def __init__(self, log_dir: typing.Union[str, Path], exp_name: str, debug: bool = False):
        """Create a ``SummaryWriter`` writing to ``log_dir / exp_name``.

        Args:
            log_dir: Parent directory for the event files.
            exp_name: Experiment name, used as the sub-directory name.
            debug: Accepted for interface compatibility; unused here.
        """
        log_dir = Path(log_dir) / exp_name
        logger.info("tensorboard file at: %s", log_dir)
        # Attribute name kept as ``logger`` so existing external accesses still work.
        self.logger = SummaryWriter(log_dir=str(log_dir))

    def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
        """Warn that config logging is unavailable; ``config`` is ignored."""
        # ``Logger.warn`` is a deprecated alias; use ``warning`` like the rest of the file.
        logger.warning("Cannot log config when using a TBVisualizer. "
                       "Configure wandb for this functionality")

    def watch(self, model: nn.Module) -> None:
        """Warn that model watching is unavailable; ``model`` is ignored."""
        logger.warning("Cannot watch models when using a TBVisualizer. "
                       "Configure wandb for this functionality")

    def log_metrics(self,
                    metrics_dict: typing.Dict[str, float],
                    split: str,
                    step: int) -> None:
        """Write each metric as a scalar under the tag ``"<split>/<name>"``.

        Args:
            metrics_dict: Mapping from metric name to value.
            split: Prefix for the scalar tag.
            step: Global step passed to ``add_scalar``.
        """
        for name, value in metrics_dict.items():
            self.logger.add_scalar(f"{split}/{name}", value, step)
class WandBVisualizer(TAPEVisualizer):
    """Visualizer that forwards config, model and metrics to Weights & Biases."""

    def __init__(self, log_dir: typing.Union[str, Path], exp_name: str, debug: bool = False):
        """Initialise a wandb run named ``exp_name`` with ``log_dir`` as its directory.

        ``WANDB_MODE`` is set to ``dryrun`` when ``debug`` is true or when the
        ``WANDB_PROJECT`` environment variable is not set.

        Raises:
            ImportError: If the ``wandb`` package could not be imported.
        """
        if not WANDB_FOUND:
            raise ImportError("wandb module not available")

        dryrun_setting = {'WANDB_MODE': 'dryrun'}

        if debug:
            os.environ.update(dryrun_setting)

        if os.environ.get('WANDB_PROJECT') is None:
            # Want the user to set the WANDB_PROJECT.
            logger.warning("WANDB_PROJECT environment variable not found, "
                           "not logging to app.wandb.ai")
            os.environ.update(dryrun_setting)

        wandb.init(dir=log_dir, name=exp_name)

    def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
        """Merge ``config`` into the current run's ``wandb.config``."""
        wandb.config.update(config)

    def watch(self, model: nn.Module):
        """Pass ``model`` to ``wandb.watch``."""
        wandb.watch(model)

    def log_metrics(self,
                    metrics_dict: typing.Dict[str, float],
                    split: str,
                    step: int):
        """Log each metric under the key ``"<Split> <Name>"`` at ``step``.

        Both the split and the metric name are passed through ``str.capitalize``.
        """
        payload = {}
        for name, value in metrics_dict.items():
            payload[f"{split.capitalize()} {name.capitalize()}"] = value
        wandb.log(payload, step=step)
def get(log_dir: typing.Union[str, Path],
        exp_name: str,
        local_rank: int,
        debug: bool = False) -> TAPEVisualizer:
    """Build the visualizer to use for the given ``local_rank``.

    Args:
        log_dir: Directory handed to the visualizer.
        exp_name: Experiment name handed to the visualizer.
        local_rank: Rank of the calling process; only -1 and 0 get a real backend.
        debug: Debug flag handed to the visualizer.

    Returns:
        A ``DummyVisualizer`` when ``local_rank`` is neither -1 nor 0;
        otherwise a ``WandBVisualizer`` if wandb was importable, else a
        ``TBVisualizer``.
    """
    logs_here = local_rank in (-1, 0)
    if not logs_here:
        visualizer_cls = DummyVisualizer
    else:
        visualizer_cls = WandBVisualizer if WANDB_FOUND else TBVisualizer
    return visualizer_cls(log_dir, exp_name, debug)