su_eval.py
import os
import sys
import json
import time
import argparse
import torch
import datetime
from sklearn.metrics import recall_score, precision_score, accuracy_score
import numpy as np
from tqdm import tqdm
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
)
from lib.mapper import Mapper
from lib.predictor import su_inference
from lib.eval_utils import *
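# note: the star import above is expected to provide bool2binary() and
# eval_from_json(), which are used below but defined in lib.eval_utils.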


def generate_default_json_name():
    now = datetime.datetime.now()
    return now.strftime('%m%d_%H%M')


if __name__ == '__main__':
    import warnings
    warnings.filterwarnings("ignore")

    # input args
    parser = argparse.ArgumentParser(description='Run LLaVA model inference.')
    parser.add_argument(
        '--json_file', type=str,
        required=True,
        help='Path to the input JSON file (image path -> ground-truth label)'
    )
    parser.add_argument(
        '--prompt_file', type=str,
        default='./su_prompts/prompt.json',
        help='Prompt file path'
    )
    parser.add_argument(
        '--make_json', action='store_true',
        help='write the results to an output JSON file'
    )
    input_args = parser.parse_args()

    # load the data samples (image path -> ground-truth label)
    with open(input_args.json_file, 'r') as j:
        data_dict = json.load(j)

    # define input args for the LLaVA model
    args = type('Args', (), {
        "model_path": "liuhaotian/llava-v1.6-vicuna-7b",
        "model_base": None,
        "model_name": get_model_name_from_path("liuhaotian/llava-v1.6-vicuna-7b"),
        "query": None,
        "conv_mode": None,
        "image_file": None,
        "sep": ",",
        "temperature": 0.1,
        "top_p": 0.99,
        "num_beams": 1,
        "max_new_tokens": 32
    })()
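    # note: temperature, top_p, num_beams and max_new_tokens are the decoding
    # settings; they are presumably forwarded to the model's generate() call
    # by su_inference (assumption: lib.predictor is not shown in this file).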

    # load pre-trained model
    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        args.model_path, args.model_base, model_name)
    print(f'\n ** finish loading model {model_name} ** \n')

    # load query prompts
    with open(input_args.prompt_file, 'r') as f:
        query_dict = json.load(f)
    print("query dict : ")
    for i, j in query_dict.items():
        print(i, '->', j)
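    # note: query_dict appears to map a sub-question name to its prompt text;
    # each prompt is asked about every image inside su_inference (assumption).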

    recall, precision, accuracy, latency = [], [], [], []
    result_dict = {}
    for sample, gt in tqdm(data_dict.items()):
        # update the current data sample
        args.image_file = sample

        # inference, returning a dict of answer text
        start_time = time.time()
        answer_dict = su_inference(args, model_name, tokenizer, model, image_processor, None, query_dict)
        end_time = time.time()
        execution_time = end_time - start_time
        latency.append(execution_time)
        # print(f"\n{sample} --> su_inference with: {execution_time:.2f} sec\n")

        # map the answer text to a dict of booleans
        mapper = Mapper(answer_dict)
        bool_dict = mapper.answer2bool()
        # merge sub-scene results into the output dict
        output_dict = mapper.merge_bool()
        # convert the bool dict to a one-hot prediction
        pred = bool2binary(output_dict)

        # record
        result_dict[sample] = {
            'answer_dict': answer_dict,
            'bool_dict': bool_dict,
            'pred': pred,
            'ground_truth': gt
        }

    # output a json file and evaluate it
    if input_args.make_json:
        json_path = f'./su_data/outputs/{generate_default_json_name()}.json'
        with open(json_path, 'w') as j:
            json.dump(result_dict, j, indent=4)
        print(f'\nrecord results -> {json_path} \n')

        # evaluate (json_path only exists when --make_json is set)
        eval_from_json(json_path)
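
# example usage (the annotation path below is illustrative, not from this repo):
#   python su_eval.py --json_file ./su_data/annotations/val.json --make_json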