# reward_model.py

from typing import List

import torch
from torch import nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer, T5EncoderModel, T5PreTrainedModel, T5Model, PreTrainedModel, BitsAndBytesConfig
from transformers.modeling_outputs import SequenceClassifierOutput
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training

from reward_model_dataset import LABELS, format_input, USE_LIKERT
from utils import device
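
# LoRA adapter settings used when fine-tuning the LLaMA-style reward models;
# the classification head ("score") is trained fully and saved alongside the adapter.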
peft_config = LoraConfig(
    modules_to_save=["score"],
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    task_type="SEQ_CLS",
    inference_mode=False,
)
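
# 8-bit (bitsandbytes) quantization used when training the LLaMA-style models;
# the classification head ("score") is excluded from int8 conversion.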
bnb_config = BitsAndBytesConfig(
    llm_int8_skip_modules=["score"],
    load_in_8bit=True,
)
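
# Encoder-only T5 reward model: mean-pools the encoder hidden states over
# non-padding tokens and scores them with a two-layer classification head.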
class T5Classifier(T5PreTrainedModel):
    _tied_weights_keys = ["encoder.embed_tokens.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = T5EncoderModel(config)
        self.classifier = nn.Sequential(
            nn.Linear(config.d_model, config.d_model),
            nn.Tanh(),
            nn.Dropout(p=config.classifier_dropout),
            nn.Linear(config.d_model, config.num_labels)
        )
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.post_init()
        self.model_parallel = False

    def forward(self, input_ids, attention_mask, labels=None, **kwargs):
        model_outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        seq_lens = attention_mask.sum(dim=1)
        masked_hidden_states = model_outputs.last_hidden_state * attention_mask.unsqueeze(2)
        avg_hidden_states = masked_hidden_states.sum(dim=1) / seq_lens.unsqueeze(1)
        logits = self.classifier(avg_hidden_states)
        if labels is not None:
            loss = self.loss_fn(logits, labels)
        else:
            loss = None
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
        )
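
# Encoder-decoder T5 reward model: runs the full T5 stack and classifies from the
# decoder hidden state at the last non-padding decoder position.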
class T5ClassifierWDecoder(T5PreTrainedModel):
    _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = T5Model(config)
        self.classifier = nn.Sequential(
            nn.Linear(config.d_model, config.d_model),
            nn.Tanh(),
            nn.Dropout(p=config.classifier_dropout),
            nn.Linear(config.d_model, config.num_labels)
        )
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.post_init()
        self.model_parallel = False

    def forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, labels=None, **kwargs):
        decoder_input_ids = nn.functional.pad(decoder_input_ids, (1, 0), value=self.config.decoder_start_token_id)
        decoder_attention_mask = nn.functional.pad(decoder_attention_mask, (1, 0), value=1)
        model_output = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
        )
        seq_lens = decoder_attention_mask.sum(dim=1)
        hidden_states = model_output.last_hidden_state[torch.arange(input_ids.shape[0]), seq_lens - 1]
        logits = self.classifier(hidden_states)
        if labels is not None:
            loss = self.loss_fn(logits, labels)
        else:
            loss = None
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
        )
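
# Averages the logits of several reward models. Note that the sub-models are held in a
# plain Python list rather than an nn.ModuleList, so they are not registered as child
# modules of the ensemble.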
class EnsembleModel(nn.Module):
    def __init__(self, models):
        super().__init__()
        self.models = models
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask, labels=None, **kwargs):
        all_logits = [
            model(input_ids, attention_mask, labels=labels).logits
            for model in self.models
        ]
        logits = torch.stack(all_logits).mean(dim=0)
        if labels is not None:
            loss = self.loss_fn(logits, labels)
        else:
            loss = None
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
        )
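
# Load the tokenizer for the given base model; LLaMA-style checkpoints ship without a
# pad token, so the EOS token is reused and padding is applied on the right.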
def get_tokenizer(base_model: str):
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    if "llama" in base_model or "meta-math" in base_model:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"
    return tokenizer
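
# Build a single reward model. T5 checkpoints use the custom classifier heads above;
# LLaMA-style models are loaded in 8-bit with a LoRA adapter for training, or loaded
# in fp16 with the trained adapter merged back in at test time.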
def get_reward_model(model_name: str, base_model: str, tokenizer, enc_dec: bool, test: bool):
    if "t5" in base_model:
        if enc_dec:
            model = T5ClassifierWDecoder.from_pretrained(model_name, num_labels=len(LABELS), classifier_dropout=0.0).to(device)
        else:
            model = T5Classifier.from_pretrained(model_name, num_labels=len(LABELS), classifier_dropout=0.0).to(device)
    elif "llama" in base_model or "meta-math" in base_model:
        model = AutoModelForSequenceClassification.from_pretrained(
            base_model,
            num_labels=len(LABELS),
            problem_type="multi_label_classification",
            pad_token_id=tokenizer.pad_token_id,
            quantization_config=None if test else bnb_config,
            # Higher precision for non-quantized parameters helps training accuracy and doesn't hurt performance
            # Lower precision at test time improves speed and only marginally hurts performance
            torch_dtype=torch.float16 if test else torch.float32,
            device_map={"": 0}
        )
        model.config.use_cache = False
        model.config.pretraining_tp = 1
        if test:
            model = PeftModel.from_pretrained(model, model_name).merge_and_unload()
        else:
            model = prepare_model_for_kbit_training(model)
            model = get_peft_model(model, peft_config)
    else:
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=len(LABELS), problem_type="multi_label_classification").to(device)
    return model
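
# Load one reward model per checkpoint name and wrap them in a logit-averaging ensemble.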
def get_ensemble(model_names: List[str], base_model: str, tokenizer, enc_dec: bool, test: bool):
    models = [
        get_reward_model(model_name, base_model, tokenizer, enc_dec, test)
        for model_name in model_names
    ]
    return EnsembleModel(models)
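
# Build a scoring function over (metadata, predicted text) pairs. model_name may be a
# comma-separated list of checkpoints, which are ensembled. Inputs are formatted with
# format_input, scored in small batches, and squashed to [0, 1] with a sigmoid; outside
# of test mode the per-label scores are collapsed to a single scalar per example.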
def get_check_correct_batch_fn(model_name: str, base_model: str, test: bool):
    use_t5_decoder = False
    enc_dec = ("t5" in base_model) and use_t5_decoder
    tokenizer = get_tokenizer(base_model)
    rm = get_ensemble(model_name.split(","), base_model, tokenizer, enc_dec, True)
    rm.eval()

    def check_correct_batch(src_meta_datas: List[dict], pred_texts: List[str]):
        nonlocal tokenizer, rm
        with torch.no_grad():
            batch_size = 4
            all_scores = []
            for batch_start_idx in range(0, len(pred_texts), batch_size):
                meta_data_batch = src_meta_datas[batch_start_idx : batch_start_idx + batch_size]
                pred_text_batch = pred_texts[batch_start_idx : batch_start_idx + batch_size]
                inputs = tokenizer(
                    [format_input(meta_data, pred_text) for meta_data, pred_text in zip(meta_data_batch, pred_text_batch)],
                    padding=True, truncation=True, return_tensors="pt"
                ).to(device)
                logits = rm(**inputs).logits
                scores = torch.sigmoid(logits)
                if not USE_LIKERT:
                    scores[:, :2] = 1 - scores[:, :2]
                if not test:
                    scores = scores[:, 0] * scores.mean(dim=1)
                all_scores.append(scores)
            return torch.concat(all_scores).cpu()

    return check_correct_batch
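

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). The checkpoint path, the
    # base model name, and the metadata dicts below are placeholders: format_input in
    # reward_model_dataset.py defines what the metadata must actually contain.
    check_correct_batch = get_check_correct_batch_fn(
        model_name="path/to/reward-model-checkpoint",  # may be "ckpt_a,ckpt_b" for an ensemble
        base_model="google/flan-t5-base",              # assumed T5-style base model
        test=True,
    )
    example_meta_datas = [{}]                 # placeholder metadata dicts expected by format_input
    example_preds = ["a predicted solution"]  # candidate texts to score
    scores = check_correct_batch(example_meta_datas, example_preds)
    print(scores)  # per-label sigmoid scores, shape (num_examples, len(LABELS)) when test=True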