-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_bad_questions.py
104 lines (87 loc) · 3.38 KB
/
generate_bad_questions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
'''
Use OpenAI GPT to generate bad questions
'''
import os
import json
from collections import defaultdict
from utils import *
from tqdm import tqdm
from prompt_llm import *
def seperate_turns(dialouge):
    '''
    Split a dialogue transcript into per-turn strings.

    A turn starts at a line containing 'User:' and accumulates every
    following line (concatenated without newline separators, matching
    the original behavior) until the next 'User:' line.

    Parameters
    ----------
    dialouge : str
        Raw dialogue text, one utterance per line.

    Returns
    -------
    list[str]
        One string per user turn, in order.
    '''
    collect_turns = []
    clear_line = ''
    for line in dialouge.split('\n'):
        if 'User:' in line:
            # flush the previous turn before starting a new one
            collect_turns.append(clear_line)
            clear_line = ''
        clear_line += line
    # BUG FIX: the final turn was accumulated but never flushed, so the
    # last user turn of every dialogue was silently dropped.
    if clear_line:
        collect_turns.append(clear_line)
    # Drop the leading empty string produced before the first 'User:' line.
    # Guarded so a dialogue with no 'User:' marker no longer raises IndexError.
    if collect_turns and collect_turns[0] == '':
        collect_turns.pop(0)
    return collect_turns
def construct_input_data(turns):
    '''
    Build one question-generator input per turn.

    Each input is: the full history of all earlier turns, then the
    current user utterance (assistant reply stripped), then an empty
    'Assistant:' cue for the model to complete.

    Parameters
    ----------
    turns : list[str]
        Per-turn dialogue strings as produced by seperate_turns().

    Returns
    -------
    list[str]
        One prompt string per turn.
    '''
    inputs = []
    for idx, turn in enumerate(turns):
        # keep only the user half of the current turn
        user_part = turn[:turn.find('Assistant:')].strip()
        # full, unmodified history of everything said before this turn
        history = ''.join(turns[:idx])
        inputs.append(history + user_part + 'Assistant:\n')
    return inputs
def main():
    '''
    Generate "bad" Socratic questions for every training dialogue in the
    benchmark and checkpoint the results to bad_questions.json.

    Resumes from an existing bad_questions.json, skipping any training
    file whose results were already saved by a previous (partial) run.
    '''
    # configure the OpenAI credentials (helper comes from the utils star import)
    set_api_key()
    data_path = 'socratic-debugging-benchmark/socratic_debugging_benchmark/v2_sigcse'
    train_path = os.path.join(data_path, 'train')
    # maps training filename -> list of generated bad questions (one per turn)
    store_results = defaultdict(list)
    # check if results file already exists
    if os.path.exists('bad_questions.json'):
        with open('bad_questions.json', 'r') as infile:
            store_results_load = json.load(infile)
        # copy store_results_load into store_results
        for key, value in store_results_load.items():
            store_results[key] = value
    for ctr, tr_file in tqdm(enumerate(os.listdir(train_path)), total=len(os.listdir(train_path))):
        # resume support: skip files already present in the checkpoint
        if tr_file in store_results.keys():
            continue
        tr_file_path = os.path.join(train_path, tr_file)
        with open(tr_file_path, 'r') as f:
            conversation_data = f.read()
        # extract problem meta data - everything until </bug_fixes>
        problem_meta_data = conversation_data[:conversation_data.find('</bug_fixes>')+len('</bug_fixes>')].strip()
        # print(problem_meta_data)
        # pull the <dialogue>…</dialogue> body (helper from the utils star import)
        dialouge = extract_text_in_tags(conversation_data, '<dialogue>', '</dialogue>')
        # seperate turns
        turns = seperate_turns(dialouge)
        # print(turns)
        # construct input data
        all_input_conversation = construct_input_data(turns)
        # print('#### Input Data ####')
        # print(all_input_conversation[2])
        for cctr, conversation in enumerate(all_input_conversation):
            # construct input prompt
            input_prompt = problem_meta_data + '\n\n<dialogue>' + conversation + '\nOUTPUT:\n'
            # print('#### Input Prompt ####')
            # print(input_prompt)
            # generate bad question
            llm_response = prompt_bad_question_generation(input_prompt)
            # print('#### LLM Response ####')
            # print(llm_response)
            store_results[tr_file].append(llm_response)
        # save results
        # NOTE(review): the original indentation was lost in this paste; the save
        # is placed at the end of each file's iteration (per-file checkpointing,
        # consistent with the resume logic above) — confirm against upstream.
        with open('bad_questions.json', 'w') as outfile:
            json.dump(store_results, outfile, indent=6)
if __name__ == '__main__':
    main()