-
Notifications
You must be signed in to change notification settings - Fork 0
/
annotator.py
48 lines (37 loc) · 1.54 KB
/
annotator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""This module contains the `Annotator` class hierarchy used to annotate
preferences over generations.
Each batch is a dictionary with the following format:
{
'prompt': str, # prompt for the generated texts
'prompt_token_ids': tensor, # token ids of the prompt
'prompt_attention_mask': tensor, # attention mask of the prompt
'generations1': str, # text of the 1st generation
'generations1_token_ids': tensor, # token ids of the 1st generation
'generations1_attention_mask': tensor,
# attention mask of the 1st generation
'generation1_reward': float, # reward of the 1st generation
'generation1_weight': float, # weight of the 1st generation
'generations2': str, # text of the 2nd generation
'generations2_token_ids': tensor, # token ids of the 2nd generation
'generations2_attention_mask': tensor,
# attention mask of the 2nd generation
'generation2_reward': float, # reward of the 2nd generation
'generation2_weight': float, # weight of the 2nd generation
}
"""
from typing import List, Dict, Any
class Annotator:
    """Root of the annotator class hierarchy (placeholder).

    Presumably meant to annotate preferences over the generations in a
    batch (see the module docstring). Not implemented yet: constructing an
    instance always raises ``NotImplementedError``.
    """

    def __init__(self) -> None:
        # Stub: construction is intentionally unsupported until implemented.
        raise NotImplementedError()
class TensorAnnotator(Annotator):
    """Annotator whose inputs are PyTorch tensors (placeholder).

    Not implemented yet: constructing an instance always raises
    ``NotImplementedError``.
    """

    def __init__(self) -> None:
        # Stub: no tensor-based annotation logic exists yet.
        raise NotImplementedError()
class RewardAnnotator(TensorAnnotator):
    """Tensor annotator intended to use a reward model (placeholder).

    Meant to annotate a given PyTorch tensor with a reward model. Not
    implemented yet: constructing an instance always raises
    ``NotImplementedError``.
    """

    def __init__(self) -> None:
        # Stub: no reward model is loaded or wired up yet.
        raise NotImplementedError()
class TextAnnotator(Annotator):
    """Annotator whose inputs are strings (placeholder).

    Not implemented yet: constructing an instance always raises
    ``NotImplementedError``.
    """

    def __init__(self) -> None:
        # Stub: no text-based annotation logic exists yet.
        raise NotImplementedError()