-
Notifications
You must be signed in to change notification settings - Fork 0
/
batchfactory.py
398 lines (343 loc) · 14.2 KB
/
batchfactory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
"""This module contains the `BatchFactory` class to create batches of training data.
Each batch is a dictionary of the following format:
{
'prompt': str, # prompt for the generated texts
'prompt_token_ids': tensor, # token ids of the prompt
'prompt_attention_mask': tensor, # attention mask of the prompt
'generations1': str, # text of the 1st generation
'generations1_token_ids': tensor, # token ids of the 1st generation
'generations1_attention_mask': tensor,
# attention mask of the 1st generation
'generation1_reward': float, # reward of the 1st generation
'generation1_weight': float, # weight of the 1st generation
'generations2': str, # text of the 2nd generation
'generations2_token_ids': tensor, # token ids of the 2nd generation
'generations2_attention_mask': tensor,
# attention mask of the 2nd generation
'generation2_reward': float, # reward of the 2nd generation
'generation2_weight': float, # weight of the 2nd generation
}
"""
import random
import torch
from transformers import PreTrainedModel, AutoTokenizer
from typing import Dict, Optional, Union
import DataLoader
from samplereweighter import SampleReweighter
from annotator import Annotator
class BatchFactory:
    """Base class for factories that yield batches of training data.

    Concrete subclasses must implement `_check_type`, `__iter__`, and
    `_get_dataloader`. The constructor validates the generator/annotator
    pairing (both or neither) and derives three mode flags:
    `online`, `pairwise`, and `reweight`.
    """

    def __init__(
        self,
        dataset_name: str,  # e.g. 'hh', 'shp'
        tokenizer,  # Huggingface tokenizer object
        generator: Union[None, "PreTrainedModel"] = None,
        # None for offline data, otherwise a Huggingface model
        annotator: Union[None, "Annotator"] = None,
        # None for offline data, otherwise an Annotator object
        reweighter: Union[None, "SampleReweighter"] = None,
        # None for no reweighting, otherwise a SampleReweighter object
        split: str = 'train',
        batch_size: int = 1,
        max_length: int = 512,
        max_prompt_length: int = 128,
        n_epochs: Optional[int] = None,
        n_examples: Optional[int] = None,
        human_prefix: str = '\n<|user|>\n',  # marks start of human's turn
        human_suffix: str = '',  # marks end of human's turn
        # marks start of assistant's turn
        assistant_prefix: str = '\n<|assistant|>\n',
        assistant_suffix: str = '',  # marks end of assistant's turn
        seed: int = 0,
        device: str = 'cpu',  # device for the generator
        **kwargs
    ) -> None:
        torch.manual_seed(seed)
        random.seed(seed)
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.split = split
        self.batch_size = batch_size
        self.max_length = max_length
        self.max_prompt_length = max_prompt_length
        self.human_prefix = human_prefix
        self.human_suffix = human_suffix
        self.assistant_prefix = assistant_prefix
        self.assistant_suffix = assistant_suffix
        self.kwargs = kwargs
        self.generator = generator
        self.annotator = annotator
        self.reweighter = reweighter
        # Raise instead of assert: asserts are stripped under `python -O`,
        # which would silently skip this input validation.
        if n_epochs is None and n_examples is None:
            raise ValueError("Must specify either n_epochs or n_examples")
        self.n_epochs = n_epochs
        self.epoch_idx = 0
        self.n_examples = n_examples
        self.online = self._get_online_flag()
        self.pairwise = self._get_pariwise_flag()
        self.reweight = self._get_reweight_flag()
        self._check_type()
        self.data_loader = self._get_dataloader()
        self.device = device

    def _check_type(self):
        # Subclasses validate that their mode flags are consistent.
        raise NotImplementedError

    def __iter__(self):
        # Subclasses yield batches (see module docstring for the format).
        raise NotImplementedError

    def _get_dataloader(self):
        # Subclasses construct the appropriate DataLoader instance.
        raise NotImplementedError

    def _get_online_flag(self):
        """Return True iff both generator and annotator are provided."""
        if self.generator is not None and self.annotator is not None:
            return True
        elif self.generator is None and self.annotator is None:
            return False
        else:
            raise ValueError(
                'Must specify both generator and annotator or neither'
            )

    def _get_pariwise_flag(self):
        # Pairwise mode is inferred from the subclass name.
        # NOTE: method name kept for backward compatibility
        # ("pariwise" is a typo for "pairwise").
        return 'pairwise' in type(self).__name__.lower()

    def _get_reweight_flag(self):
        """Return True iff a reweighter was provided."""
        return self.reweighter is not None
class OfflineBatchFactory(BatchFactory):
    """BatchFactory for offline data with or without reweighting.

    Generates offline batches for training the policy model. The batches
    come from the `DataLoader` class. If `reweight` is True, the
    `SampleReweighter` is applied to adjust the weights of generations in
    each batch. Otherwise, batches are yielded directly from the loader.
    """

    def __iter__(self):
        # BUG FIX: the original reweight branch returned a single batch
        # dict from __iter__ (breaking the iterator protocol) and
        # restarted the loader's iterator on every call. Iterating the
        # loader once and yielding (possibly reweighted) batches
        # preserves the non-reweight behavior and fixes the reweight path.
        for batch in self.data_loader:
            if self.reweight:
                batch = self.reweighter(batch)
            yield batch

    def _check_type(self):
        assert self.online is False, \
            'Cannot use OfflineBatchFactory for online data'
        self._offline_check_type()

    def _offline_check_type(self):
        # Subclass-specific mode validation (pointwise/pairwise/SFT).
        raise NotImplementedError
class SFTBatchFactory(OfflineBatchFactory):
    """BatchFactory for SFT data.
    Each batch is a dictionary of the following format:
    {
        'prompt': str, # prompt for the generated texts
        'prompt_token_ids': tensor, # token ids of the prompt
        'prompt_attention_mask': tensor, # attention mask of the prompt
        'generations1': str, # text of prompt + 1st generation
        'generation1_response_only': str, # text of the 1st generation only
        'generations1_token_ids': tensor, # token ids of the 1st generation
        'generations1_attention_mask': tensor, # attn mask of the 1st generation
        'generation1_reward': float, # reward of the 1st generation
        'generation1_weight': float, # weight of the 1st generation
    }
    """
    def _get_dataloader(self):
        # Build the SFT loader; all arguments are passed positionally.
        # NOTE(review): self.batch_size appears twice in this call (4th
        # and 7th positional args) — verify against SFTDataLoader's
        # signature; the second occurrence may be meant for a different
        # parameter.
        return DataLoader.SFTDataLoader(
            self.dataset_name,
            self.tokenizer,
            self.split,
            self.batch_size,
            self.max_length,
            self.max_prompt_length,
            self.batch_size,
            self.n_epochs,
            self.n_examples,
            self.human_prefix,
            self.human_suffix,
            self.assistant_prefix,
            self.assistant_suffix,
            **self.kwargs
        )
    def _offline_check_type(self):
        # SFT has a single target response per prompt; reweighting does
        # not apply.
        assert self.reweight is False, \
            'SFTBatchFactory should not be used with reweighting.'
class OfflinePointwiseBatchFactory(OfflineBatchFactory):
    """BatchFactory for offline data with pointwise feedback.
    Each batch is a dictionary of the following format:
    {
        'prompt': str, # prompt for the generated texts
        'prompt_token_ids': tensor, # token ids of the prompt
        'prompt_attention_mask': tensor, # attention mask of the prompt
        'generations1': str, # text of prompt + 1st generation
        'generation1_response_only': str, # text of the 1st generation only
        'generations1_token_ids': tensor, # token ids of the 1st generation
        'generations1_attention_mask': tensor, # attn mask of the 1st generation
        'generation1_reward': float, # reward of the 1st generation
        'generation1_weight': float, # weight of the 1st generation
    }
    """
    def _get_dataloader(self):
        # NOTE(review): self.batch_size appears twice (4th and 7th
        # positional args) — verify against
        # PointwiseFeedbackDataLoader's signature.
        return DataLoader.PointwiseFeedbackDataLoader(
            self.dataset_name,
            self.tokenizer,
            self.split,
            self.batch_size,
            self.max_length,
            self.max_prompt_length,
            self.batch_size,
            self.n_epochs,
            self.n_examples,
            human_prefix=self.human_prefix,
            human_suffix=self.human_suffix,
            assistant_prefix=self.assistant_prefix,
            assistant_suffix=self.assistant_suffix,
            **self.kwargs
        )
    def _offline_check_type(self):
        # Pointwise feedback: exactly one scored generation per prompt.
        assert self.pairwise is False, \
            'OfflinePointwiseBatchFactory cannot be used for pairwise feedback'
class OfflinePairwiseBatchFactory(OfflineBatchFactory):
    """BatchFactory for offline data with pairwise feedback.
    Each batch is a dictionary of the following format:
    {
        'prompt': str, # prompt for the generated texts
        'prompt_token_ids': tensor, # token ids of the prompt
        'prompt_attention_mask': tensor, # attention mask of the prompt
        'generation1': str, # text of prompt + 1st generation
        'generation1_response_only': str, # text of the 1st generation only
        'generation1_token_ids': tensor, # token ids of the 1st generation
        'generation1_attention_mask': tensor, # attn mask of the 1st generation
        'generation1_reward': float, # reward of the 1st generation
        'generation1_weight': float, # weight of the 1st generation
        'generation2': str, # text of prompt + 2nd generation
        'generation2_response_only': str, # text of the 2nd generation only
        'generation2_token_ids': tensor, # token ids of the 2nd generation
        'generation2_attention_mask': tensor, # attn mask of the 2nd generation
        'generation2_reward': float, # reward of the 2nd generation
        'generation2_weight': float, # weight of the 2nd generation
    }
    """
    def _get_dataloader(self):
        # TODO(review): this pairwise factory constructs
        # PointwiseFeedbackDataLoader — likely a copy-paste error from
        # OfflinePointwiseBatchFactory; confirm whether a
        # PairwiseFeedbackDataLoader exists and should be used here.
        # NOTE(review): self.batch_size also appears twice (4th and 7th
        # positional args) — verify against the loader's signature.
        return DataLoader.PointwiseFeedbackDataLoader(
            self.dataset_name,
            self.tokenizer,
            self.split,
            self.batch_size,
            self.max_length,
            self.max_prompt_length,
            self.batch_size,
            self.n_epochs,
            self.n_examples,
            human_prefix=self.human_prefix,
            human_suffix=self.human_suffix,
            assistant_prefix=self.assistant_prefix,
            assistant_suffix=self.assistant_suffix,
            **self.kwargs
        )
    def _offline_check_type(self):
        # Pairwise feedback: two scored generations per prompt.
        assert self.pairwise is True, \
            'OfflinePairwiseBatchFactory must be used for pairwise feedback'
class OnlineBatchFactory(BatchFactory):
    """BatchFactory with online annotator and generator.

    Generates online batches for training the policy model. The data
    loader used in this class is `PromptDataLoader`, which yields only
    the prompts in the dataset. Responses are generated by the generator
    on the fly and then annotated by the annotator before being returned
    as a batch.
    """

    def _check_type(self):
        # BUG FIX: the original message referred to
        # OfflinePairwiseBatchFactory; this class is OnlineBatchFactory
        # and the failing condition is offline data.
        assert self.online is True, \
            'Cannot use OnlineBatchFactory for offline data'
        self._online_check_type()

    def _get_dataloader(self):
        # NOTE(review): self.batch_size appears twice (4th and 7th
        # positional args) — verify against PromptDataLoader's signature.
        return DataLoader.PromptDataLoader(
            self.dataset_name,
            self.tokenizer,
            self.split,
            self.batch_size,
            self.max_length,
            self.max_prompt_length,
            self.batch_size,
            self.n_epochs,
            self.n_examples,
            self.human_prefix,
            self.human_suffix,
            self.assistant_prefix,
            self.assistant_suffix,
            **self.kwargs
        )

    def _get_batch(self):
        # Subclasses produce one generated-and-annotated batch.
        raise NotImplementedError

    @torch.no_grad()  # no need to compute gradients for on-policy sampling
    def _get_sample_from_generator(
        self,
        prompt_batch: Dict[str, torch.Tensor],
        id: int = 1,
    ):
        """Generate responses for the given prompt batch.

        Args:
            prompt_batch: a batch of prompts containing at least
                'prompt_input_ids' and 'prompt_attention_mask' tensors.
            id: index of the generation (1 or 2); currently unused in
                this method, kept for interface symmetry with the
                annotator.

        Returns:
            A tuple `(token_ids, attention_mask)` of tensors for the
            generated sequences (the original docstring described a dict,
            but the code returns a tuple).
        """
        assert self.tokenizer.padding_side == 'left', \
            'Only left padding is supported for sampling a batch from generator'
        outputs = self.generator.generate(
            # BUG FIX: use the configured device instead of hard-coded
            # 'cuda' — `device` is a constructor parameter and was
            # otherwise ignored.
            input_ids=prompt_batch['prompt_input_ids'].to(self.device),
            attention_mask=prompt_batch['prompt_attention_mask'].to(self.device),
            max_new_tokens=self.max_length - self.max_prompt_length,
            return_dict_in_generate=True,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        token_ids = outputs.sequences
        # TODO: the following mask needs to be checked and updated
        # (pad tokens inside the prompt vs. after EOS are not
        # distinguished).
        attention_mask = (token_ids != self.tokenizer.pad_token_id).long()
        return token_ids, attention_mask

    def _get_feedback_from_annotator(
        self,
        generations_token_ids: torch.Tensor,
        generations_attention_mask: torch.Tensor,
    ):
        # Subclasses score generated responses via the annotator.
        raise NotImplementedError

    def __iter__(self):
        # BUG FIX: the original returned a single batch object from
        # __iter__, breaking the iterator protocol. Yield batches until
        # the underlying prompt loader signals exhaustion.
        # NOTE(review): subclasses' _get_batch restarts the loader
        # iterator on each call — confirm the loader tracks its own
        # position, otherwise this repeats the first batch.
        while True:
            try:
                batch = self._get_batch()
            except StopIteration:
                return  # prompt loader exhausted
            if self.reweight:
                batch = self.reweighter(batch)
            yield batch

    def _online_check_type(self):
        assert self.annotator is not None, \
            'OnlineBatchFactory must have an annotator'
        assert self.generator is not None, \
            'OnlineBatchFactory must have a generator'
class OnlinePointwiseBatchFactory(OnlineBatchFactory):
    """Online batch factory for pointwise feedback: one generation per
    prompt, scored by the annotator."""

    def _get_batch(self):
        # NOTE(review): `next(self.data_loader.__iter__())` builds a
        # fresh iterator on every call, which returns the first batch
        # each time unless the loader tracks its own position — confirm
        # against PromptDataLoader.
        prompt_batch = next(self.data_loader.__iter__())
        generations1_token_ids, generations1_attention_mask = \
            self._get_sample_from_generator(prompt_batch, id=1)
        return self.annotator.annotate(
            prompt_batch,
            generations1_token_ids,
            generations1_attention_mask,
            id=1,
        )

    def _online_check_type(self):
        # BUG FIX: `self.super()` is not valid Python (AttributeError at
        # runtime); use the `super()` builtin.
        super()._online_check_type()
        assert self.pairwise is False, \
            'OnlinePointwiseBatchFactory cannot be used for pairwise feedback'
class OnlinePairwiseBatchFactory(OnlineBatchFactory):
    """Online batch factory for pairwise feedback.

    NOTE(review): no `_get_batch` override is visible here, so iteration
    falls through to OnlineBatchFactory._get_batch (NotImplementedError)
    — the pairwise sampling path may be defined elsewhere or missing.
    """

    def _online_check_type(self):
        # BUG FIX: `self.super()` is not valid Python (AttributeError at
        # runtime); use the `super()` builtin.
        super()._online_check_type()
        assert self.pairwise is True, \
            'OnlinePairwiseBatchFactory must be used for pairwise feedback'