transfer_pretrained.py
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
import torch
from utils.load_datasets import load_MR, load_Semeval2017A
from training import get_metrics_report
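
# Evaluates off-the-shelf Hugging Face sentiment-analysis models on the chosen
# dataset's test set, mapping each model's own label scheme onto the dataset labels.
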
# DATASET = 'Semeval2017A'
DATASET = 'MR'
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
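# Note: the transformers pipeline used below runs on CPU by default; it accepts a
# `device` argument (e.g. device=0) to run on a GPU if one is available.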

LABELS_MAPPING = {
    'siebert/sentiment-roberta-large-english': {
        'POSITIVE': 'positive',
        'NEGATIVE': 'negative',
    },
    'cardiffnlp/twitter-roberta-base-sentiment': {
        'LABEL_0': 'negative',
        'LABEL_1': 'neutral',
        'LABEL_2': 'positive',
    },
    'textattack/bert-base-uncased-imdb': {
        'LABEL_0': 'negative',
        'LABEL_1': 'positive',
    },
    'textattack/bert-base-uncased-yelp-polarity': {
        'LABEL_0': 'negative',
        'LABEL_1': 'positive',
    },
    'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis': {
        'positive': 'positive',
        'neutral': 'neutral',
        'negative': 'negative',
    },
    'Seethal/sentiment_analysis_generic_dataset': {
        'LABEL_0': 'negative',
        'LABEL_1': 'neutral',
        'LABEL_2': 'positive',
    },
    'cardiffnlp/twitter-xlm-roberta-base-sentiment': {
        'Positive': 'positive',
        'Neutral': 'neutral',
        'Negative': 'negative',
    },
    'j-hartmann/sentiment-roberta-large-english-3-classes': {
        'positive': 'positive',
        'neutral': 'neutral',
        'negative': 'negative',
    }
}
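
# Each pretrained model emits its own label strings (e.g. 'LABEL_2' or 'POSITIVE');
# LABELS_MAPPING normalizes them to the dataset's label names, so a raw pipeline
# prediction such as {'label': 'LABEL_2', 'score': 0.97} from
# 'cardiffnlp/twitter-roberta-base-sentiment' is recorded as 'positive'.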

if __name__ == '__main__':
    # load the raw data
    if DATASET == "Semeval2017A":
        pretrained_models = ['cardiffnlp/twitter-roberta-base-sentiment',
                             'Seethal/sentiment_analysis_generic_dataset',
                             'j-hartmann/sentiment-roberta-large-english-3-classes']
        # 'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis',
        # 'cardiffnlp/twitter-xlm-roberta-base-sentiment'
    elif DATASET == "MR":
        pretrained_models = ['siebert/sentiment-roberta-large-english',
                             'textattack/bert-base-uncased-imdb',
                             'textattack/bert-base-uncased-yelp-polarity']
    else:
        raise ValueError("Invalid dataset")

    for PRETRAINED_MODEL in pretrained_models:
        if DATASET == 'Semeval2017A':
            X_train, y_train, X_test, y_test = load_Semeval2017A()
        else:
            X_train, y_train, X_test, y_test = load_MR()

        # encode labels
        le = LabelEncoder()
        le.fit(list(set(y_train)))
        y_train = le.transform(y_train)
        y_test = le.transform(y_test)
        n_classes = len(list(le.classes_))
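        # LabelEncoder sorts its classes alphabetically, so the same encoder can later
        # map the normalized string predictions to the same integer ids as y_test.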

        # define a proper pipeline
        sentiment_pipeline = pipeline("sentiment-analysis", model=PRETRAINED_MODEL)

        y_pred = []
        for x in tqdm(X_test):
            # TODO: Main-lab-Q6 - get the label using the defined pipeline
            result = sentiment_pipeline(x)[0]
            label = result['label']
            y_pred.append(LABELS_MAPPING[PRETRAINED_MODEL][label])
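
        # Note: the pipeline also accepts a list of texts (e.g.
        # sentiment_pipeline(list(X_test), batch_size=32)), which is usually
        # faster than calling it one example at a time.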
        y_pred = le.transform(y_pred)
        print(f'\nDataset: {DATASET}\nPre-Trained model: {PRETRAINED_MODEL}\nTest set evaluation\n{get_metrics_report([y_test], [y_pred])}')