-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
61 lines (48 loc) · 1.56 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from typing import Dict, Any
from datasets import load_dataset
def get_tldr_pair_dict(post: Dict[str, Any]) -> Dict[str, str]:
    """Get a dict containing a (chosen, rejected) summary pair for a TL;DR post.

    Args:
        post: one record in the HuggingFace dataset format. The fields read
            are ``post['info']['title']``, ``post['info']['post']``,
            ``post['summaries']`` (a sequence of dicts with a ``'text'`` key)
            and ``post['choice']`` (index of the preferred summary).

    Returns:
        post_dict: a dictionary of the following format
            {
                'post': a string of the post ("TITLE: ...\\nTEXT: ...\\n")
                'chosen': a string of the chosen summary
                'rejected': a string of the rejected summary
            }
        In both summaries every "\\n\\n" is collapsed to a single space;
        single newlines are kept.
    """
    info = post['info']
    post_text = f"TITLE: {info['title']}\nTEXT: {info['post']}\n"

    summaries = post['summaries']
    choice = post['choice']
    # `1 - choice` selects the other summary; this assumes exactly two
    # summaries and choice in {0, 1} (e.g. choice == 2 would index -1).
    chosen_text = summaries[choice]['text'].replace('\n\n', ' ')
    rejected_text = summaries[1 - choice]['text'].replace('\n\n', ' ')

    return {
        'post': post_text,
        'chosen': chosen_text,
        'rejected': rejected_text,
    }
def get_tldr_post_list(split: str = 'validation'):
    """Get the list of (chosen, rejected) summary pairs from the TL;DR dataset.

    Loads the ``comparisons`` configuration of
    ``openai/summarize_from_feedback`` for the requested split and converts
    every record with ``get_tldr_pair_dict``. Each element of the returned
    list is organised as follows:
        {
            'post': a string of the post
            'chosen': a string of the chosen summary
            'rejected': a string of the rejected summary
        }

    Args:
        split: name of the dataset split to load, e.g. 'train' or
            'validation'.
            NOTE(review): the ``comparisons`` config appears to provide only
            'train' and 'validation' (no 'test') — confirm against the
            dataset card.

    Returns:
        post_list: list of dicts in the format above, one per dataset record.
    """
    # Forward the caller's split; hard-coding it here would silently ignore
    # the argument and always return the validation data.
    dataset = load_dataset(
        'openai/summarize_from_feedback',
        'comparisons',
        split=split,
    )
    return [get_tldr_pair_dict(post) for post in dataset]