-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate_completions.py
102 lines (79 loc) · 3.33 KB
/
evaluate_completions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import argparse
import os
import sys
from datetime import datetime
from functools import partial
import pandas as pd
import torch
from huggingface_hub import login as hf_login
from tqdm import tqdm
from baselines.ppl_filter import Filter
def get_df(results_dir, df_name):
folder_name = "_".join(df_name.split("_")[:-2])
if not df_name.endswith(".csv"):
df_name = df_name + ".csv"
path = os.path.join(results_dir, folder_name, df_name)
df = pd.read_csv(path)
return df, path
def main():
parser = argparse.ArgumentParser(description="Оценка завершений с использованием PPL фильтра.")
parser.add_argument("--df_name", type=str, required=True, help="Имя CSV-файла DataFrame")
parser.add_argument(
"--results_dir", type=str, required=True, help="Каталог, где хранятся результаты"
)
parser.add_argument(
"--model_id", type=str, required=True, help="ID модели для использования в guard модели"
)
parser.add_argument(
"--base_path", type=str, required=False, help="Базовый путь для униграмм и биграмм"
)
parser.add_argument("--batch_size", type=int, default=1, help="Размер батча для обработки")
args = parser.parse_args()
hf_login(token=os.environ.get("HUGGINGFACE_API_KEY"))
if args.base_path:
base_path = args.base_path
else:
base_path = os.environ.get("WORK_DIR") + "/ngrams_results_final/joined_final"
df, df_path = get_df(args.results_dir, args.df_name)
requests = df["behavior"].apply(lambda x: str(x).strip()).tolist()
responses = df["final_completion"].apply(lambda x: str(x).strip()).tolist()
batch_prompts = list(zip(requests, responses))
results = []
results_PPL_filter = []
load_local_filter = partial(
Filter,
unigram_path=os.path.join(
base_path, "arity=1", "df_gutenberg_unigrams_dict_normalized_hashed.parquet"
),
bigram_path=os.path.join(
base_path, "arity=2", "df_gutenberg_bigrams_dict_normalized_hashed.parquet"
),
verbose_guard=False,
dummy_load=False,
guard_model_device="auto",
judge_return_output=True,
)
local_filter = load_local_filter(guard_model_id=args.model_id)
local_filter.guard_tokenizer.pad_token = local_filter.guard_tokenizer.eos_token
local_filter.guard_model_device = "cuda"
batch_size = args.batch_size
for i in tqdm(range(0, len(batch_prompts), batch_size), desc="Evaluating completions"):
batch = batch_prompts[i : i + batch_size]
for user_message, assistant_message in batch:
out = local_filter.judge(
user_message=user_message,
assistant_message=assistant_message,
user_context=None,
max_new_tokens=5,
)
filter_passed = local_filter.apply_filter(
user_message, 8, "perplexity", -38276.122748721245
)
results.append(out[0])
results_PPL_filter.append(filter_passed)
df["judge_output"] = results
df["filter_passed"] = results_PPL_filter
df.to_csv(df_path, index=False)
print("The file is saved to", df_path)
if __name__ == "__main__":
main()