generated from HephaestusProject/template
-
Notifications
You must be signed in to change notification settings - Fork 1
/
evaluate.py
112 lines (94 loc) · 3.67 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import json
from argparse import ArgumentParser, Namespace
from pathlib import Path
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer, seed_everything
from src.model.net import FCN, ShortChunkCNN_Res
from src.task.pipeline import DataPipeline
from src.task.runner import AutotaggingRunner
def get_config(args: Namespace) -> DictConfig:
    """Assemble the experiment configuration from per-component YAML files.

    For each section in ``("model", "pipeline", "runner")`` the file
    ``conf/<args.dataset>/<section>/<getattr(args, section)>.yaml`` is loaded
    with ``OmegaConf.load`` and stored under the key ``section``.

    Args:
        args: Parsed CLI arguments; must provide ``dataset``, ``model``,
            ``pipeline`` and ``runner`` attributes.

    Returns:
        A ``DictConfig`` with top-level keys ``model``, ``pipeline`` and
        ``runner``.
    """
    dataset_dir = Path("conf") / args.dataset
    merged = OmegaConf.create()
    # Insertion order model -> pipeline -> runner matches the original loads.
    sections = {
        name: OmegaConf.load(dataset_dir / name / f"{getattr(args, name)}.yaml")
        for name in ("model", "pipeline", "runner")
    }
    merged.update(**sections)
    return merged
def main(args: Namespace) -> None:
    """Evaluate a trained checkpoint and record the metrics in ``results.json``.

    Steps:
      1. Build the dataloader for ``args.type`` from the pipeline config.
      2. Instantiate the model named ``args.model`` and wrap it in an
         ``AutotaggingRunner``.
      3. Restore weights from
         ``exp/<dataset>/<model>/<runner>/<checkpoint>.ckpt``.
      4. Run ``Trainer.test`` and merge the outcome into
         ``exp/<dataset>/<model>/<runner>/results.json``, creating the file
         if it does not exist yet.

    Args:
        args: Parsed CLI arguments (see the ``__main__`` block).

    Raises:
        ValueError: If ``args.model`` is not a supported architecture.
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    seed_everything(42)
    config = get_config(args)

    # prepare dataloader
    pipeline = DataPipeline(pipline_config=config.pipeline)
    dataset = pipeline.get_dataset(
        pipeline.dataset_builder,
        config.pipeline.dataset.path,
        args.type,
        config.pipeline.dataset.input_length,
    )
    # NOTE(review): drop_last=True discards the final partial batch, so some
    # samples of this split are never evaluated. Confirm that is intended.
    dataloader = pipeline.get_dataloader(
        dataset,
        shuffle=False,
        drop_last=True,
        **pipeline.pipeline_config.dataloader.params,
    )

    # Explicit dispatch so an unsupported name fails here with a clear message
    # instead of later as an UnboundLocalError on ``model``.
    model_classes = {"ShortChunkCNN_Res": ShortChunkCNN_Res, "FCN": FCN}
    if args.model not in model_classes:
        raise ValueError(
            f"Unsupported model {args.model!r}; "
            f"expected one of {sorted(model_classes)}"
        )
    model = model_classes[args.model](**config.model.params)
    runner = AutotaggingRunner(model, config.runner)

    exp_dir = Path("exp") / args.dataset / args.model / args.runner

    # map_location="cpu" lets a checkpoint saved on GPU load on a host without
    # CUDA. load_state_dict copies values into the runner's existing
    # parameters, so the device the tensors were loaded onto does not matter.
    checkpoint = torch.load(exp_dir / f"{args.checkpoint}.ckpt", map_location="cpu")
    # Index (not .get) so a malformed checkpoint raises a clear KeyError
    # instead of passing None into load_state_dict.
    runner.load_state_dict(checkpoint["state_dict"])

    trainer = Trainer(
        **config.runner.trainer.params, logger=False, checkpoint_callback=False
    )

    results_path = exp_dir / "results.json"
    # Start from any previously stored results so keys not written below are
    # kept; the "checkpoint" and split entries are overwritten.
    results = {}
    if results_path.exists():
        with open(results_path, mode="r", encoding="utf-8") as io:
            results = json.load(io)

    result = trainer.test(runner, test_dataloaders=dataloader)
    results.update({"checkpoint": args.checkpoint, args.type: result})

    with open(results_path, mode="w", encoding="utf-8") as io:
        json.dump(results, io, indent=4)
if __name__ == "__main__":
    # Command-line options selecting the dataset, config variants and the
    # checkpoint to evaluate; parsed values are handed straight to main().
    cli = ArgumentParser()
    cli.add_argument(
        "--model",
        type=str,
        default="ShortChunkCNN_Res",
        choices=["FCN", "ShortChunkCNN_Res"],
    )
    cli.add_argument("--type", type=str, default="TEST", choices=["TEST"])
    cli.add_argument("--dataset", type=str, default="mtat", choices=["mtat"])
    cli.add_argument(
        "--pipeline",
        type=str,
        default="pv_AudioInput3sec",
        choices=["pv_AudioInput3sec", "pv_AudioInput30sec"],
    )
    cli.add_argument("--runner", type=str, default="rv00", choices=["rv00", "rv01"])
    # NOTE(review): --reproduce is parsed but never read by main() or get_config().
    cli.add_argument("--reproduce", action="store_true", default=False)

    known_checkpoints = [
        "epoch=23-roc_auc=0.9044-pr_auc=0.4403",
        "epoch=25-roc_auc=0.8929-pr_auc=0.4043",
    ]
    cli.add_argument(
        "--checkpoint",
        type=str,
        default="epoch=25-roc_auc=0.8929-pr_auc=0.4043",
        choices=known_checkpoints,
    )

    args = cli.parse_args()
    main(args)