-
Notifications
You must be signed in to change notification settings - Fork 0
/
Evaluate.py
86 lines (68 loc) · 2.55 KB
/
Evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import argparse
import glob
import os
from pathlib import Path
from sklearn.metrics import balanced_accuracy_score
from tqdm import tqdm
import Input
from Model.ConvolutionalHate import ConvolutionalHate
from Model.CustomMaxHate import CustomMaxHate
from Model.CustomHate import CustomHate
from Model.LSTMHate import LSTMHate
from Output import outputResult
# Pretrained model identifier looked up by the --language argument value.
PRETRAINED_MODEL = {
    "en": "bert-base-cased",
    "es": "dccuchile/bert-base-spanish-wwm-cased",
}
# Maps the --model argument value to the class whose ``load_from_checkpoint``
# is used to restore the classifier.
MODEL_MAP = {
    "CustomHate": CustomHate,
    "CustomMaxHate": CustomMaxHate,
    "LSTMHate": LSTMHate,
    "ConvolutionalHate": ConvolutionalHate,
}
def _build_parser():
    """Create the command-line parser used by this script."""
    cli = argparse.ArgumentParser(description="Feed classifier")
    # Each row: (short flag, long flag, help text, required).
    option_specs = (
        ("-l", "--language", "Language", True),
        ("-p", "--pretrained_model", "Pretrained Model", False),
        ("-m", "--model", "Model", True),
        ("-mp", "--model_path", "Model CKPT path", False),
        ("-i", "--input_path", "Data to predict", False),
        ("-o", "--output_path", "Path for the prediction", False),
    )
    for short_flag, long_flag, help_text, is_required in option_specs:
        cli.add_argument(short_flag, long_flag,
                         help=help_text, required=is_required)
    return cli


parser = _build_parser()
def get_model_prediction(model, input):
    """Return the model's predicted class index for one input as an ``int``.

    ``model`` is called with ``input``; the result's ``argmax()`` (called
    with no axis argument) is converted to a built-in ``int``.
    """
    # presumably the output is a tensor of per-class scores — confirm with
    # the model classes.
    return int(model(input).argmax())
if __name__ == "__main__":
    # Script entry point: resolve paths and the model from the CLI arguments,
    # restore the checkpoint, then emit one prediction per input feed.
    args = parser.parse_args()

    # Input data defaults to the per-language test folder.
    input_path = args.input_path or f"data_test/{args.language}/"

    # An explicit -p/--pretrained_model takes precedence; otherwise fall back
    # to the default registered for the language. The argparse attribute is
    # `pretrained_model` (derived from the long option name).
    if args.pretrained_model:
        pretrained_model = args.pretrained_model
    elif args.language in PRETRAINED_MODEL:
        pretrained_model = PRETRAINED_MODEL[args.language]
    else:
        parser.error(
            f"no default pretrained model for language {args.language!r}; "
            f"use one of: {', '.join(PRETRAINED_MODEL)} "
            "or pass -p/--pretrained_model"
        )

    # Fail with a usage message instead of a bare KeyError later on.
    if args.model not in MODEL_MAP:
        parser.error(
            f"unknown model {args.model!r}; "
            f"choose one of: {', '.join(MODEL_MAP)}"
        )

    if args.model_path:
        model_path = args.model_path
    else:
        # Search the conventional log folder for any checkpoint. Only the last
        # path component of the pretrained model name is used in the folder.
        models_path = f'lightning_logs/{args.language}/{pretrained_model.split("/")[-1]}/{args.model}'
        checkpoints = glob.glob(models_path + '/**/*.ckpt', recursive=True)
        if not checkpoints:
            raise FileNotFoundError(
                f"no .ckpt file found under {models_path!r}; "
                "pass -mp/--model_path explicitly"
            )
        # NOTE: glob order is not guaranteed; the first match is used.
        model_path = checkpoints[0]

    # Predictions are written next to the checkpoint unless told otherwise.
    output_path = args.output_path or os.path.dirname(model_path)

    model = MODEL_MAP[args.model].load_from_checkpoint(
        checkpoint_path=model_path, pretrained_model_name=pretrained_model)
    # NOTE(review): confirm the restored model is in eval mode before
    # predicting.

    data = Input.get_data_dict(input_path, with_label=False)
    Path(output_path).mkdir(parents=True, exist_ok=True)

    for feed in tqdm(data):
        outputResult(
            id=feed['id'],
            type=get_model_prediction(model, feed['input']),
            lang=args.language,
            prepath=output_path
        )