-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsamplereweighter.py
34 lines (28 loc) · 1.24 KB
/
samplereweighter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
"""This module contains the `SampleReweighter` class, which adjusts the
weights of generations in a batch.

Each batch is a dictionary with the following keys:
    {
        'prompt': str,                          # prompt for the generated texts
        'prompt_token_ids': tensor,             # token ids of the prompt
        'prompt_attention_mask': tensor,        # attention mask of the prompt
        'generations1': str,                    # text of the 1st generation
        'generations1_token_ids': tensor,       # token ids of the 1st generation
        'generations1_attention_mask': tensor,
                                                # attention mask of the 1st generation
        'generation1_reward': float,            # reward of the 1st generation
        'generation1_weight': float,            # weight of the 1st generation
        'generations2': str,                    # text of the 2nd generation
        'generations2_token_ids': tensor,       # token ids of the 2nd generation
        'generations2_attention_mask': tensor,
                                                # attention mask of the 2nd generation
        'generation2_reward': float,            # reward of the 2nd generation
        'generation2_weight': float,            # weight of the 2nd generation
    }

NOTE(review): the text/token/mask keys use the prefix ``generations<N>``,
while the reward/weight keys use ``generation<N>_``. Confirm this mismatch
against the code that builds the batches before relying on these names.
"""
class SampleReweighter:
    """Base class for objects that adjust the weights of generations in a batch.

    The expected batch layout is described in the module docstring. This
    class is currently only a placeholder: constructing it raises
    ``NotImplementedError``.
    """

    def __init__(self) -> None:
        """Placeholder constructor.

        Raises:
            NotImplementedError: always; no implementation exists yet.
        """
        raise NotImplementedError()
class TensorAnnotator(SampleReweighter):
    """Placeholder ``SampleReweighter`` subclass whose inputs are PyTorch tensors.

    Constructing it currently raises ``NotImplementedError``.

    NOTE(review): the class is named *Annotator* even though it derives from
    ``SampleReweighter``. Confirm the intended name before building on it.
    """

    def __init__(self) -> None:
        # Stub: the tensor-based implementation has not been written yet.
        raise NotImplementedError()