-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
61 lines (50 loc) · 1.66 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import re
# Answer-choice letters A-F that stand alone, i.e. are not glued to another
# ASCII letter. The lookarounds keep the scan from picking up capitals inside
# English words ("Answer", "DNA", "Based") while still matching a choice that
# directly touches CJK text such as "答案是B". A plain `\b` cannot be used
# because Python's Unicode `\w` treats CJK ideographs as word characters.
_CHOICE_WITH_PAREN = re.compile(r'(?<![A-Za-z])[A-F][)）]')
_BARE_CHOICE = re.compile(r'(?<![A-Za-z])[A-F](?![A-Za-z])')


def find_choice(text):
    """Return the index of the first answer-choice letter in *text*.

    A letter immediately followed by a closing parenthesis (ASCII ``)`` or
    full-width ``）``), e.g. ``"C)"``, is preferred. Otherwise the first
    stand-alone letter A-F is used. Letters embedded in longer ASCII words
    are ignored in both passes.

    Args:
        text: Model output to scan.

    Returns:
        Index of the choice letter within *text*, or -1 if none is found.
    """
    # Try the stricter "X)" form first, then fall back to a bare letter.
    for pattern in (_CHOICE_WITH_PAREN, _BARE_CHOICE):
        match = pattern.search(text)
        if match:
            return match.start()
    return -1
def is_ans_format(text: str):
    """Heuristically decide whether *text* reads like a stated answer.

    Keyword rules are checked in order and the first phrase found decides:
    '不是正確答案' ("is not the correct answer") -> False,
    '正確答案' ("correct answer") -> True,
    '不正確' ("incorrect") -> False,
    '正確' ("correct") -> True.
    If none of those phrases is present, the text counts as an answer when it
    contains any of the capital letters A-E.

    Args:
        text: A paragraph or line of model output.

    Returns:
        True if the text looks like it states an answer, else False.
    """
    # Order matters: each negated phrase contains its positive counterpart
    # ('不是正確答案' contains '正確答案', '不正確' contains '正確'), so the
    # negated form must be tested first.
    keyword_rules = (
        ('不是正確答案', False),
        ('正確答案', True),
        ('不正確', False),
        ('正確', True),
    )
    for phrase, verdict in keyword_rules:
        if phrase in text:
            return verdict
    # NOTE(review): only A-E are checked here, whereas find_choice in this
    # file also accepts F -- confirm whether that difference is intended.
    return any(letter in text for letter in ('A', 'B', 'C', 'D', 'E'))
def check_ans(raw_response: str, answer: str):
    """Grade a direct (non chain-of-thought) response against the gold letter.

    The stripped response is split into paragraphs separated by a blank line.
    The first paragraph is graded if is_ans_format accepts it; otherwise the
    last paragraph is graded. The choice letter located by find_choice in
    that paragraph is compared with *answer*.

    Args:
        raw_response: Raw model output.
        answer: Expected choice letter, e.g. 'B'.

    Returns:
        True if the located letter equals *answer*; False if it differs or no
        choice letter is found.
    """
    paragraphs = raw_response.strip().split('\n\n')
    head = paragraphs[0]
    # Prefer an answer stated up front; otherwise trust the closing paragraph.
    prediction_text = head if is_ans_format(head) else paragraphs[-1]
    choice_pos = find_choice(prediction_text)
    return choice_pos != -1 and prediction_text[choice_pos] == answer
def check_ans_cot(raw_response: str, answer: str):
    """Grade a chain-of-thought response against the gold choice letter.

    Everything from the first '問題:' ("Question:") marker onward is dropped.
    Presumably this discards a follow-up question the model goes on to
    generate (assumption -- confirm against the prompt format).

    The remaining text is split into lines. The last line accepted by
    is_ans_format is concatenated, with no separator, to all lines after it.
    If that text contains '正確答案' ("correct answer"), the choice letter is
    searched for only from that phrase onward.

    Args:
        raw_response: Raw model output.
        answer: Expected choice letter, e.g. 'B'.

    Returns:
        True if the located letter equals *answer*; False if it differs, no
        answer-like line exists, or no choice letter is found.
    """
    body = raw_response.strip().split('問題:')[0].strip()
    lines = body.split('\n')
    # Scan from the bottom so the final answer-like line wins; '' if none.
    prediction_text = next(
        (
            ''.join(lines[idx:])
            for idx in reversed(range(len(lines)))
            if is_ans_format(lines[idx])
        ),
        '',
    )
    # str.find gives -1 when the phrase is absent; clamp to scan everything.
    ans_pos = max(prediction_text.find('正確答案'), 0)
    tail = prediction_text[ans_pos:]
    offset = find_choice(tail)
    return offset != -1 and tail[offset] == answer