-
Notifications
You must be signed in to change notification settings - Fork 4
/
get_test_score_scienceqa.py
47 lines (36 loc) · 1.48 KB
/
get_test_score_scienceqa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
import json
import sys
def compute_accuracy(path):
    """Return the fraction of records whose predicted answer equals the ground truth.

    Args:
        path: Path to a JSON file holding a list of records, each with
            ``'pred_ans'`` and ``'gt_ans'`` keys.

    Returns:
        Accuracy as a float in [0, 1]. An empty record list yields 0.0
        rather than raising ZeroDivisionError.

    Raises:
        KeyError: if a record lacks ``'pred_ans'`` or ``'gt_ans'``.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    if not data:
        return 0.0
    correct_answers = sum(1 for item in data if item['pred_ans'] == item['gt_ans'])
    return correct_answers / len(data)


def find_latest_subdir(base_dir):
    """Return the path of the most recently modified subdirectory of *base_dir*.

    Only direct children that are directories are considered (per
    ``os.path.isdir``); regular files are ignored. If several subdirectories
    share the newest mtime, the first one in ``os.listdir`` order wins.

    Args:
        base_dir: Directory whose immediate subdirectories are compared.

    Returns:
        ``os.path.join(base_dir, <name>)`` for the subdirectory with the
        largest ``os.path.getmtime``.

    Raises:
        OSError: (e.g. FileNotFoundError) if *base_dir* cannot be listed.
        ValueError: if *base_dir* contains no subdirectories.
    """
    candidates = [os.path.join(base_dir, name) for name in os.listdir(base_dir)]
    subdirs = [p for p in candidates if os.path.isdir(p)]
    if not subdirs:
        # Same exception type max() would raise, but with an actionable message.
        raise ValueError(f"No subdirectories found in {base_dir!r}")
    return max(subdirs, key=os.path.getmtime)


def save_accuracy_to_json(path, accuracy):
    """Write ``{"test_accuracy": accuracy}`` to *path* as 4-space-indented JSON.

    Any existing file at *path* is overwritten.
    """
    payload = {"test_accuracy": accuracy}
    serialized = json.dumps(payload, indent=4)
    with open(path, 'w') as out_file:
        out_file.write(serialized)


if __name__ == '__main__':
    # Expect exactly one argument: the integer N selecting scienceqa_<N>.
    if len(sys.argv) != 2:
        print(f"Usage: python {os.path.basename(sys.argv[0])} <int_value>")
        sys.exit(1)
    try:
        int_value = int(sys.argv[1])
    except ValueError:
        print(f"Error: <int_value> must be an integer, got {sys.argv[1]!r}")
        sys.exit(1)

    # Root holding the scienceqa_<N> run folders. Set SCIENCEQA_RESULTS_ROOT
    # to adapt to a local environment; the default keeps the original path.
    results_root = os.environ.get("SCIENCEQA_RESULTS_ROOT", "/input/results/scienceqa")
    base_path = os.path.join(results_root, f"scienceqa_{int_value}")
    if not os.path.isdir(base_path):
        print(f"Results directory not found at {base_path}")
        sys.exit(1)

    # The newest run (by directory mtime) is the one that gets scored.
    try:
        latest_dir = find_latest_subdir(base_path)
    except ValueError:
        print(f"No run subdirectories found in {base_path}")
        sys.exit(1)

    result_dir = os.path.join(latest_dir, "result")
    json_path = os.path.join(result_dir, "test_scienceqa_result.json")
    if not os.path.isfile(json_path):
        print(f"JSON file not found at {json_path}")
        # Non-zero exit so wrapping scripts/CI can detect the failure.
        sys.exit(1)

    accuracy = compute_accuracy(json_path)
    print(f"Accuracy: {accuracy * 100:.2f}%")
    # Save accuracy to a new JSON file next to the result file.
    accuracy_json_path = os.path.join(result_dir, "test_accuracy.json")
    save_accuracy_to_json(accuracy_json_path, accuracy)