run_eval.py
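"""Evaluation entry point.

Runs a chosen paradigm (direct, CoT, ReAct, or ReWOO/PWS) over a sampled dataset
with the specified toolset and base LMs, and optionally saves the results to CSV.
"""
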
import argparse
import os
import warnings
warnings.filterwarnings("ignore")
parser = argparse.ArgumentParser(description='Eval Arguments.')
parser.add_argument('--method',
                    type=str,
                    choices=['direct', 'cot', 'react', 'rewoo'],
                    help='Paradigm to use')
parser.add_argument('--dataset',
                    type=str,
                    choices=["hotpot_qa", "trivia_qa", "gsm8k", "physics_question",
                             "sports_understanding", "strategy_qa", "sotu_qa"],
                    help='Dataset to use')
parser.add_argument('--sample_size',
                    type=int,
                    default=10,
                    help='Number of samples to evaluate')
parser.add_argument('--toolset',
                    nargs='+',
                    default=['Google', 'Wikipedia', 'WolframAlpha', 'Calculator', 'LLM'],
                    help='Tools available to ALMs.')
parser.add_argument('--base_lm',
                    type=str,
                    default='text-davinci-003',
                    help='Base language model to use. Can be text-davinci-003, gpt-3.5-turbo, or a directory containing an alpaca-lora checkpoint')
parser.add_argument('--planner_lm',
                    type=str,
                    help='Base LM for the Planner. Defaults to base_lm')
parser.add_argument('--solver_lm',
                    type=str,
                    help='Base LM for the Solver. Defaults to base_lm')
parser.add_argument('--save_result',
                    action='store_true',
                    help='Save results to a file')
parser.add_argument('--seed',
                    type=int,
                    default=2024,
                    help='Random seed')
parser.add_argument('--key_path',
                    type=str,
                    default='./keys/',
                    help='Path where your openai.key and serpapi.key are stored. Defaults to ./keys/')
args = parser.parse_args()
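
# Example invocation (illustrative):
#   python run_eval.py --method rewoo --dataset hotpot_qa --sample_size 10 --save_result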
with open(os.path.join(args.key_path, 'openai.key'), 'r') as f:
    os.environ["OPENAI_API_KEY"] = f.read().strip()
with open(os.path.join(args.key_path, 'serpapi.key'), 'r') as f:
    os.environ["SERPAPI_API_KEY"] = f.read().strip()
from algos.PWS import *
from algos.notool import IO, CoT
from algos.react import ReactBase, ReactExtraTool
from utils.DataLoader import DataLoader
from utils.Evaluator import Evaluator
from utils.util import *
def save_data(dataset, data, save_path):
    # Attach predictions and evaluation metrics to the dataset, write it to CSV, and return it.
    dataset["preds"] = data["preds"]
    dataset["em"] = data["em"]
    dataset["f1"] = data["f1"]
    dataset["acc"] = data["acc"]
    dataset["wall_time"] = data["wall_time"]
    dataset["total_tokens"] = data["total_tokens"]
    dataset["steps"] = data["steps"]
    dataset["tool_cost"] = data["tool_cost"]
    dataset["token_cost"] = data["token_cost"]
    dataset["total_cost"] = data["total_cost"]
    dataset.to_csv(save_path, index=False)
    return dataset
def main(args):
    dataset = DataLoader(args.dataset, seed=args.seed).load(sample_size=args.sample_size)
    # Instantiate the chosen paradigm and wrap it in an Evaluator.
    if args.method == 'direct':
        method = IO(model_name=args.base_lm)
        eval = Evaluator(args.dataset, dataset, method)
    elif args.method == 'cot':
        method = CoT(model_name=args.base_lm, fewshot=DEFAULT_EXEMPLARS_COT[args.dataset])
        eval = Evaluator(args.dataset, dataset, method)
    elif args.method == 'react':
        if args.dataset in ['hotpot_qa', 'trivia_qa']:
            method = ReactBase(model_name=args.base_lm, fewshot=DEFAULT_EXEMPLARS_REACT[args.dataset], verbose=False)
        else:
            method = ReactExtraTool(model_name=args.base_lm, available_tools=args.toolset,
                                    fewshot=DEFAULT_EXEMPLARS_REACT[args.dataset], verbose=False)
        eval = Evaluator(args.dataset, dataset, method)
    elif args.method == 'rewoo':
        # Planner and Solver may use different base LMs; fall back to base_lm if unset.
        if args.planner_lm is None:
            args.planner_lm = args.base_lm
        if args.solver_lm is None:
            args.solver_lm = args.base_lm
        method = PWS_Base(planner_model=args.planner_lm, solver_model=args.solver_lm,
                          fewshot=DEFAUL_EXEMPLARS_PWS[args.dataset], available_tools=args.toolset)
        eval = Evaluator(args.dataset, dataset, method)
    else:
        raise NotImplementedError
    responses, data = eval.run()
    if args.save_result:
        save_data(dataset, data, f'./results/eval_{args.dataset}_{args.method}_{args.base_lm}.csv')
    print(responses)


if __name__ == '__main__':
    main(args)