-
Notifications
You must be signed in to change notification settings - Fork 0
/
hp_search.py
151 lines (130 loc) · 5.95 KB
/
hp_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import argparse
import csv
import logging
import os
import shutil
import subprocess
import time
from ast import literal_eval
from datetime import datetime

import transformers

from DAMF import read_data
from DAMF import DomainAdaptTrainer
from DAMF import evaluate
from DAMF import feature_embedding_analysis
from DAMF import read_config, set_seed

# Silence the HuggingFace tokenizers fork-parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def read_command_args(args):
    """Parse command-line arguments and merge them into ``args``.

    All directory arguments are resolved relative to the directory that
    contains this script; the output directory is created if it does not
    already exist.

    Parameters
    ----------
    args : dict
        Existing settings; updated in place.

    Returns
    -------
    dict
        The same ``args`` dict, now containing ``data_dir``,
        ``output_dir``, ``config_dir`` and the three hyperparameter
        search lists ``lambda_trans_lst``, ``lambda_rec_lst`` and
        ``gamma_lst`` (each parsed from its string form into a Python
        list).
    """
    parser = argparse.ArgumentParser(
        description='Domain adaptation model for moral foundation inference.')
    parser.add_argument('-c',
                        '--config_dir',
                        type=str,
                        required=True,
                        help='configuration file dir that specifies hyperparameters etc')
    parser.add_argument('-i',
                        '--data_dir',
                        type=str,
                        required=True,
                        help='input data directory')
    parser.add_argument('-o',
                        '--output_dir',
                        type=str,
                        required=True,
                        help='output directory')
    parser.add_argument('--lambda_trans_lst',
                        type=str,
                        required=True,
                        help='a list')
    parser.add_argument('--lambda_rec_lst',
                        type=str,
                        required=True,
                        help='a list')
    parser.add_argument('--gamma_lst',
                        type=str,
                        required=True,
                        help='a list')
    command_args = parser.parse_args()

    # Resolve all paths relative to this script's directory so the tool
    # behaves the same regardless of the caller's working directory.
    curr_dir = os.path.dirname(os.path.realpath(__file__))
    args['data_dir'] = os.path.join(curr_dir, command_args.data_dir)
    args['output_dir'] = os.path.join(curr_dir, command_args.output_dir)
    # exist_ok avoids the exists()/makedirs() race when several jobs
    # start concurrently (e.g. a SLURM array).
    os.makedirs(args['output_dir'], exist_ok=True)
    args['config_dir'] = os.path.join(curr_dir, command_args.config_dir)

    # literal_eval safely parses list literals like "[0.1, 1.0]"
    # (unlike eval, it rejects arbitrary expressions).
    args['lambda_trans_lst'] = literal_eval(command_args.lambda_trans_lst)
    args['lambda_rec_lst'] = literal_eval(command_args.lambda_rec_lst)
    args['gamma_lst'] = literal_eval(command_args.gamma_lst)
    return args
if __name__ == '__main__':
    # Log to <SLURM job id>.log under SLURM, else to a timestamped file.
    try:
        logfilename = os.environ["SLURM_JOB_ID"]
    except KeyError:
        # Not running under SLURM (bare except would also have swallowed
        # KeyboardInterrupt/SystemExit).
        logfilename = datetime.now().strftime("%Y%m%d%H%M%S")
    logging.basicConfig(filename=logfilename + '.log',
                        format="%(message)s",
                        level=logging.INFO)

    # Merge command-line arguments with the config file settings.
    args = {}
    args = read_command_args(args)
    args = read_config(os.path.dirname(os.path.realpath(__file__)), args)

    # Grid search over the three hyperparameter lists, keeping the
    # checkpoint with the best test accuracy.
    start_time = time.time()
    best_accu = 0.0
    # Pre-initialize so the summary log below cannot raise NameError if
    # no trial ever beats the 0.0 baseline.
    best_lambda_trans = best_lambda_rec = best_gamma = None
    for lambda_trans in args['lambda_trans_lst']:
        for lambda_rec in args['lambda_rec_lst']:
            for gamma in args['gamma_lst']:
                logging.info(f"\nStart HP search: lambda_trans={lambda_trans}, lambda_rec={lambda_rec}, gamma={gamma}")
                args['lambda_rec'] = lambda_rec
                args['lambda_trans'] = lambda_trans
                args['gamma'] = gamma
                # A zero weight disables the corresponding loss term.
                args['reconstruction'] = lambda_rec != 0
                args['transformation'] = lambda_trans != 0

                # Re-seed per trial so results are comparable across
                # hyperparameter settings.
                set_seed(args['seed'])
                datasets = read_data(args['data_dir'],
                                     args['pretrained_dir'],
                                     args['n_mf_classes'],
                                     args['train_domain'],
                                     args['test_domain'],
                                     args['semi_supervised'],
                                     args['aflite'],
                                     seed=args['seed'],
                                     train_frac=0.8)
                trainer = DomainAdaptTrainer(datasets, args)
                accu = trainer.train()

                # Score the checkpoint the trainer just wrote on the test split.
                test_accu, _ = evaluate(datasets['test'],
                                        args['batch_size'],
                                        model_path=args['output_dir'] + '/best_model.pth',
                                        is_adv=args['domain_adapt'],
                                        test=True)
                logging.info(f"\nHP search result: lambda_trans={lambda_trans}, lambda_rec={lambda_rec}, gamma={gamma}, num_no_adv/epoch={args['num_no_adv']}/{args['n_epoch']}, val accu={accu}, test accu={test_accu}")

                if test_accu > best_accu:
                    best_accu = test_accu
                    best_lambda_rec = lambda_rec
                    best_lambda_trans = lambda_trans
                    best_gamma = gamma
                    # Preserve this checkpoint before the next trial
                    # overwrites best_model.pth (portable, unlike
                    # shelling out to `cp`).
                    shutil.copyfile(args['output_dir'] + '/best_model.pth',
                                    args['output_dir'] + '/hp_search_best_model.pth')

    # Remove the per-trial checkpoint; only the overall best is kept.
    os.remove(args['output_dir'] + '/best_model.pth')

    # NOTE: best_accu tracks TEST accuracy (see the comparison above).
    logging.info(f"\nHyperparameter search finished. Best model has lambda_trans={best_lambda_trans}, lambda_rec={best_lambda_rec}, gamma={best_gamma}, num_no_adv/epoch={args['num_no_adv']}/{args['n_epoch']}. Test Macro F1={best_accu}\n")
    logging.info('============ Evaluation on Test Data ============= \n')
    # Re-evaluate the overall best checkpoint on the last trial's test
    # split (the test split is the same across trials for a fixed seed).
    test_accu, _ = evaluate(datasets['test'],
                            args['batch_size'],
                            model_path=args['output_dir'] + '/hp_search_best_model.pth',
                            is_adv=args['domain_adapt'],
                            test=True)
    logging.info('Macro F1 of the %s TEST dataset with given (best) model: %f' % ('target', test_accu))
    logging.info(f"Finished evaluating test data {args['test_domain']}. Time: {time.time()-start_time}")