-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path2generate_process.py
58 lines (47 loc) · 2.12 KB
/
2generate_process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
from PIL import Image
from functools import partial
import argparse
import json
import pickle
import traceback
from tqdm import tqdm
from engine.utils import ProgramInterpreter
# Command-line interface: where the programs and images are read from, where
# the execution records and pickled program states are written, and which GPU
# the process is allowed to see.
parser = argparse.ArgumentParser(description='Execute programs and record their processes')
parser.add_argument('--input_dir', type=str, default='dataset/processed_data/test_programs.json',
                    help='path to the input JSON file of programs')
parser.add_argument('--output_dir', type=str, default='dataset/processed_data/processes.json',
                    help='path of the output JSON file of execution records')
parser.add_argument('--image_dir', type=str, default='dataset/images/',
                    help='directory containing the input images (<image_id>.jpg or .png)')
parser.add_argument('--state_dir', type=str, default='dataset/states/',
                    help='directory that receives one pickled program state per sample')
# The device index used to be hard-coded; the default keeps the old behaviour.
parser.add_argument('--gpu', type=str, default='7',
                    help='value exported as CUDA_VISIBLE_DEVICES (default: 7)')
args = parser.parse_args()

# Restrict the CUDA devices visible to this process.
# NOTE(review): engine.utils is imported before this assignment runs; confirm
# that it does not initialise CUDA at import time, otherwise this is too late.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
def _open_image(image_dir, image_id):
    """Open ``<image_id>.jpg`` from *image_dir*, falling back to ``<image_id>.png``.

    Only a missing ``.jpg`` triggers the fallback; every other error is left
    for the caller to handle.
    """
    try:
        return Image.open(os.path.join(image_dir, image_id + '.jpg'))
    except FileNotFoundError:
        return Image.open(os.path.join(image_dir, image_id + '.png'))


def main():
    """Execute every program in the input JSON and save the outcome.

    For each sample the image is opened, shrunk in place to fit within
    640x640, converted to RGB and handed to ``interpreter.execute`` together
    with the program.  Samples whose image cannot be opened or whose program
    raises are reported on stderr and skipped.  For every successful sample
    the returned program state is pickled to ``<state_dir>/<sampled_id>.pkl``
    and a record is collected; all records are written as one JSON document
    to ``args.output_dir`` once the loop has finished.
    """
    interpreter = ProgramInterpreter(dataset='SME')

    # Close the input file deterministically instead of leaking the handle.
    with open(args.input_dir) as f:
        data = json.load(f)

    # Create the output locations up front so a missing directory fails fast
    # rather than after the programs have been executed.
    if args.state_dir:
        os.makedirs(args.state_dir, exist_ok=True)
    output_parent = os.path.dirname(args.output_dir)
    if output_parent:
        os.makedirs(output_parent, exist_ok=True)

    processes = {}
    for sampled_id, sampled_data in tqdm(data.items()):
        question = sampled_data['question']
        image_id = sampled_data['image_id']
        program = sampled_data['program']

        # `except Exception` (not a bare except) so Ctrl-C still stops the run.
        try:
            image = _open_image(args.image_dir, image_id)
        except Exception:
            traceback.print_exc()
            continue

        image.thumbnail((640, 640), Image.Resampling.LANCZOS)
        init_state = dict(
            IMAGE=image.convert('RGB')
        )

        try:
            result, prog_state, pro_str = interpreter.execute(program, init_state, inspect='text')
        except Exception:
            # Report the failure instead of dropping the sample silently.
            traceback.print_exc()
            continue

        processes[sampled_id] = {
            'question': question,
            'image_id': image_id,
            'program': program,
            'answer': sampled_data['answer'],
            'process': pro_str,
            'predicted_answer': str(result),
        }
        with open(os.path.join(args.state_dir, sampled_id + ".pkl"), 'wb') as f:
            pickle.dump(prog_state, f)

    with open(args.output_dir, 'w') as f:
        json.dump(processes, f)


if __name__ == '__main__':
    main()