-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
131 lines (92 loc) · 4.16 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
'''
* Based on the CODA-Prompt codebase:
* https://github.com/GT-RIPL/CODA-Prompt
* Builds our CDL model on the CODA-Prompt baselines (DualPrompt and L2P).
'''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import sys
import argparse
import torch
import numpy as np
import yaml
import json
import random
from trainer import Trainer
def create_args():
    """Build and return the command-line argument parser for an experiment run.

    Returns:
        argparse.ArgumentParser: parser covering standard, continual-learning,
        model-selection, knowledge-distillation, and config-file options.
    """
    parser = argparse.ArgumentParser()

    # --- Standard arguments ---
    parser.add_argument('--random_s', type=int, default=1, help="The random seed")
    parser.add_argument('--gpuid', nargs="+", type=int, default=[0],
                        help="The list of gpuid, ex:--gpuid 3 1. Negative value means cpu-only")
    parser.add_argument('--log_dir', type=str, default="outputs/out",
                        help="Save experiments results in dir for future plotting!")
    parser.add_argument('--learner_type', type=str, default='default',
                        help="The type (filename) of learner")
    parser.add_argument('--learner_name', type=str, default='NormalNN',
                        help="The class name of learner")
    parser.add_argument('--debug_mode', type=int, default=0, metavar='N',
                        help="activate learner specific settings for debug_mode")
    parser.add_argument('--overwrite', type=int, default=0, metavar='N',
                        help='Train regardless of whether saved model exists')

    # --- Continual-learning arguments ---
    parser.add_argument('--upper_bound_flag', default=False, action='store_true',
                        help='Upper bound')
    parser.add_argument('--memory', type=int, default=0,
                        help="size of memory for replay")
    parser.add_argument('--DW', default=False, action='store_true',
                        help='dataset balancing')
    parser.add_argument('--prompt_param', nargs="+", type=float, default=[1, 1, 1],
                        help="e prompt pool size, e prompt length, g prompt length")

    # --- Teacher / student model selection ---
    parser.add_argument('--t_model', default='vit_base_patch16_224', type=str,
                        metavar='MODEL', help='Name of t_model to train')
    parser.add_argument('--s_model', default='vit_tiny_patch16_224', type=str,
                        metavar='MODEL', help='Name of s_model to train')

    # --- Knowledge-distillation arguments ---
    parser.add_argument('--kd_alpha', type=float, default=0.5,
                        help="alpha of distillation loss")
    parser.add_argument('--Soft_T', type=float, default=2.,
                        help="temperature for distillation")

    # --- Experiment config file ---
    parser.add_argument('--config', type=str, default="configs/config.yaml",
                        help="yaml experiment config input")

    # --- KD method selection ---
    parser.add_argument('--KD_method', type=str, default='KD',
                        help="The KD methods")

    return parser
def get_args(argv):
    """Parse CLI args, merge them over the YAML config, return a Namespace.

    CLI values take precedence over YAML values because ``config.update``
    is applied with the parsed args last.

    Args:
        argv: list of command-line tokens (e.g. ``sys.argv[1:]``).

    Returns:
        argparse.Namespace built from the merged config + CLI args.
    """
    parser = create_args()
    args = parser.parse_args(argv)
    # Use a context manager so the config file handle is always closed
    # (the original left the file open).
    with open(args.config, 'r') as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # fine for trusted local configs, but yaml.safe_load is preferable
        # if configs ever come from untrusted sources.
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    return argparse.Namespace(**config)
# want to save everything printed to outfile
class Logger(object):
    """Tee-style stream: duplicates everything written to it to both the
    original stdout and a log file (opened in append mode).

    Intended to be installed as ``sys.stdout`` so all prints are mirrored
    into ``name``.
    """

    def __init__(self, name):
        # terminal: the real stdout captured at construction time
        # log: append-mode file receiving a copy of every write
        self.terminal = sys.stdout
        self.log = open(name, "a")

    def write(self, message):
        """Write *message* to both the terminal and the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both underlying streams.

        Bug fix: the original flushed only the log file, so terminal
        output could lag behind (e.g. under output buffering).
        """
        self.terminal.flush()
        self.log.flush()
if __name__ == '__main__':
    # Parse CLI args merged over the YAML config file.
    args = get_args(sys.argv[1:])

    # Deterministic cuDNN kernels for reproducibility (may be slower).
    torch.backends.cudnn.deterministic = True

    # Duplicate everything printed to stdout into <log_dir>/output.log.
    if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)
    log_out = args.log_dir + '/output.log'
    sys.stdout = Logger(log_out)

    # Persist the fully-resolved arguments for later inspection.
    with open(args.log_dir + '/args.yaml', 'w') as yaml_file:
        yaml.dump(vars(args), yaml_file, default_flow_style=False)

    # Metrics tracked by the trainer.
    metric_keys = ['acc','time',]
    print('************************************')
    print('* START TRAINING ')
    print('************************************')

    # Seed every RNG source with the same seed for reproducibility.
    seed = args.random_s
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Set up the trainer (project-local Trainer class).
    trainer = Trainer(args, seed, metric_keys)

    # NOTE(review): max_task is read but unused here — presumably kept for
    # future per-task metric storage; confirm before removing.
    max_task = trainer.max_task

    # Run training.
    trainer.train(args)