Commit

Support mtbench
Xu Yuanchen committed Nov 8, 2023
1 parent 67f5331 commit d799831
Showing 9 changed files with 312 additions and 13 deletions.
2 changes: 2 additions & 0 deletions applications/ColossalEval/colossal_eval/dataset/__init__.py
@@ -6,6 +6,7 @@
from .gaokaobench import GaoKaoBenchDataset
from .longbench import LongBenchDataset
from .mmlu import MMLUDataset
from .mtbench import MTBenchDataset

__all__ = [
"AGIEvalDataset",
@@ -16,4 +17,5 @@
"LongBenchDataset",
"MMLUDataset",
"ColossalDataset",
"MTBenchDataset",
]
72 changes: 72 additions & 0 deletions applications/ColossalEval/colossal_eval/dataset/mtbench.py
@@ -0,0 +1,72 @@
import copy
import json
import os
from collections import defaultdict
from typing import Dict, List

from colossal_eval.utils import get_json_list

from colossalai.logging import DistributedLogger

from .base import BaseDataset

default_inference_kwargs = {
"calculate_loss": False,
"all_classes": None,
"language": "English",
"pretrain": False,
"max_new_tokens": 1024,
"turns": 2,
}


class MTBenchDataset(BaseDataset):
"""
Dataset class for mt_bench dataset.
Data source: https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/data/mt_bench/question.jsonl
This class converts the original questions into the inference dataset format.
"""

def __init__(self, path, logger, few_shot):
self.multiturn = True
self.dataset = self.load(path, logger, few_shot)

@staticmethod
def load(path: str, logger: DistributedLogger, few_shot: bool) -> Dict:
dataset = {"test": defaultdict(dict)}

file_path = os.path.join(path, "question.jsonl")
ref_path = os.path.join(path, "reference_answer/gpt-4.jsonl")

reference = defaultdict(list)
ref_origin = get_json_list(ref_path)
for ref in ref_origin:
reference[ref["question_id"]] = ref["choices"][0]["turns"]

with open(file_path, "r", encoding="utf-8") as file:
for line in file:
question = json.loads(line)
category = question["category"]
turn_number = len(question["turns"])
data_point = {
"id": question["question_id"],
"dataset": "mtbench",
"split": "test",
"category": category,
"instruction": question["turns"],
"input": "",
"output": [],
"target": [""] * turn_number
if question["question_id"] not in reference
else reference[question["question_id"]],
}

if category in dataset["test"]:
dataset["test"][category]["data"].append(data_point)
else:
dataset["test"][category] = {
"data": [data_point],
"inference_kwargs": copy.deepcopy(default_inference_kwargs),
}

return dataset
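
For orientation, one entry in the converted "test" split produced by MTBenchDataset.load looks roughly like the sketch below. The field names mirror the data_point dict built above; the question text and id are made up, and the empty target strings reflect a category without GPT-4 reference answers.

# Illustrative only: one converted MT-Bench data point (values are placeholders).
example_data_point = {
    "id": 123,                   # hypothetical question_id from question.jsonl
    "dataset": "mtbench",
    "split": "test",
    "category": "writing",
    "instruction": [             # one user message per turn (two turns in MT-Bench)
        "First-turn question text ...",
        "Second-turn follow-up text ...",
    ],
    "input": "",
    "output": [],                # filled in turn by turn during inference
    "target": ["", ""],          # GPT-4 reference answers only for math/reasoning/coding
}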
applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/dataset_evaluator.py
@@ -1,12 +1,15 @@
import os
from typing import Dict, List

import colossal_eval.evaluate.dataset_evaluator.gpt_judge as gpt_helper
import colossal_eval.evaluate.dataset_evaluator.metrics as metric_helper
import numpy as np
import tqdm
from colossal_eval.utils import jdump

LabelBasedMetrics = ["first_token_accuracy", "matthews_correlation"]
LossBasedMetrics = ["perplexity", "ppl_score", "ppl_score_over_choices", "per_byte_perplexity", "per_byte_ppl_score"]
CombinedMetrics = ["combined_single_choice_accuracy"]
GPTMetrics = ["mtbench_single_judge"]
OtherMetrics = [
"f1_score",
"f1_zh_score",
@@ -29,8 +32,9 @@ class DatasetEvaluator(object):
"""

def __init__(self):
pass
def __init__(self, config_path: str, save_path: str):
self.config_path = config_path
self.save_path = save_path

def _calculate_label_metrics(self, metric: str, category: str):
"""Calculate label-based metrics."""
@@ -156,6 +160,24 @@ def _calculate_other_metrics(self, metric: str, category: str):
self.evaluation_results[metric][category] = (total_score, len(self.data[category]["data"]))
self.evaluation_results[metric]["ALL"] += total_score * weight

def _calculate_gpt_metrics(self, metric: str, category: str):
"""Calculate gpt metrics."""
weight = len(self.data[category]["data"]) / self.metric_total_length[metric]

metric_method = eval("gpt_helper." + metric)

judgements, avg_ratings = metric_method(self.data[category]["data"], self.config_path)
self.judgements[category] = judgements

self.evaluation_results[metric][category] = (np.mean(avg_ratings), len(self.data[category]["data"]))
self.evaluation_results[metric]["ALL"] += np.mean(avg_ratings) * weight

for i in range(avg_ratings.shape[0]):
if f"{metric}_{i+1}" not in self.evaluation_results:
self.evaluation_results[f"{metric}_{i+1}"] = {cat: 0 for cat in (["ALL"] + self.categories)}
self.evaluation_results[f"{metric}_{i+1}"][category] = (avg_ratings[i], len(self.data[category]["data"]))
self.evaluation_results[f"{metric}_{i+1}"]["ALL"] += avg_ratings[i] * weight

def _calculate_loss_metrics(self, metric: str, category: str):
"""Calculate perplexity."""
if metric == "perplexity":
@@ -217,10 +239,20 @@ def _evaluate(self):
for category in self.suggested_categories[metric]:
self._calculate_combined_metrics(metric, category)
pbar.update(1)
elif metric in GPTMetrics:
for category in self.suggested_categories[metric]:
self._calculate_gpt_metrics(metric, category)
pbar.update(1)
elif metric in OtherMetrics:
for category in self.suggested_categories[metric]:
self._calculate_other_metrics(metric, category)
pbar.update(1)
else:
raise Exception(f"{metric} not supported.")

if self.judgements:
judgement_path = os.path.join(self.save_path, f"{self.model_name}_judgements.json")
jdump(self.judgements, judgement_path)

return self.evaluation_results

@@ -240,6 +272,7 @@ def get_evaluation_results(self, data: List[Dict], dataset_name: str, model_name
self.model_name = model_name
self.categories = list(data.keys())
self.metrics = metrics
self.judgements = {}

self.evaluation_results = {
metric: {category: 0 for category in (["ALL"] + self.categories)} for metric in self.metrics
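A note on the aggregation in _calculate_gpt_metrics above: each category contributes the mean of its per-turn average ratings, weighted by the number of questions in that category. A minimal standalone sketch of that roll-up (hypothetical numbers, not the evaluator API):

import numpy as np

# Hypothetical per-category results: (average rating per turn, number of questions).
per_category = {
    "writing": (np.array([8.0, 7.5]), 10),
    "math": (np.array([6.0, 5.5]), 10),
}
total = sum(n for _, n in per_category.values())

overall = 0.0
for category, (avg_ratings, n) in per_category.items():
    weight = n / total                               # same weighting as above
    overall += float(np.mean(avg_ratings)) * weight  # accumulate into the "ALL" score

print(round(overall, 2))  # 6.75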
applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/gpt_judge.py
@@ -0,0 +1,151 @@
# Code adapted from https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge

import ast
import concurrent.futures
import copy
import json
import os
import re
import time
from typing import Any, Dict, List

import numpy as np
import openai
import tqdm

MODEL = "gpt-4"

API_MAX_RETRY = 16
API_RETRY_SLEEP = 10
API_ERROR_OUTPUT = "$ERROR$"

NEED_REF_CATS = ["math", "reasoning", "coding"]

one_score_pattern = re.compile(r"\[\[(\d+\.?\d*)\]\]")
one_score_pattern_backup = re.compile(r"\[(\d+\.?\d*)\]")


def load_mt_prompts(prompt_file: str):
prompts = {}
with open(prompt_file) as fin:
for line in fin:
line = json.loads(line)
prompts[line["name"]] = line
return prompts


def get_mt_prompt(prompts: Dict[str, str], multiturn: bool, math: bool):
if math and multiturn:
return prompts["single-math-v1-multi-turn"]
elif math and not multiturn:
return prompts["single-math-v1"]
elif not math and multiturn:
return prompts["single-v1-multi-turn"]
elif not math and not multiturn:
return prompts["single-v1"]


def chat_compeletion_openai(messages: List[Dict], temperature: float = 0.0, max_tokens: int = 2048):
output = API_ERROR_OUTPUT
model = MODEL
for _ in range(API_MAX_RETRY):
try:
response = openai.ChatCompletion.create(
model=model,
messages=messages,
n=1,
temperature=temperature,
max_tokens=max_tokens,
)
output = response["choices"][0]["message"]["content"]
break
except openai.error.OpenAIError as e:
print(type(e), e)
time.sleep(API_RETRY_SLEEP)

return output


def get_mtbench_judgements(question: Dict[str, Any], prompts: Dict[str, str]):
id = question["id"]
judgement = {"id": id, "judgements": [], "ratings": []}
category = question["category"]
math = category in NEED_REF_CATS
turn_number = len(question["instruction"])

for num in range(turn_number):
assert (len(question["target"]) >= 1 and math) or not math
kwargs = {}
if num >= 1:
prompt = get_mt_prompt(prompts, multiturn=True, math=math)
if len(question["target"]) >= 1 and math:
kwargs = {f"ref_answer_{i+1}": question["target"][i] for i in range(len(question["target"]))}
user_prompt = prompt["prompt_template"].format(
question_1=question["instruction"][0],
question_2=question["instruction"][1],
answer_1=question["output"][0],
answer_2=question["output"][1],
**kwargs,
)
else:
prompt = get_mt_prompt(prompts, multiturn=False, math=math)
if len(question["target"]) >= 1 and math:
kwargs = {"ref_answer_1": question["target"][0]}
user_prompt = prompt["prompt_template"].format(
question=question["instruction"][0],
answer=question["output"][0],
**kwargs,
)

rating = -1
sys_prompt = prompt["system_prompt"]
messages = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": user_prompt}]

judgement_str = chat_compeletion_openai(messages, temperature=0.0, max_tokens=2048)
match = re.search(one_score_pattern, judgement_str)
if not match:
match = re.search(one_score_pattern_backup, judgement_str)
if match:
rating = ast.literal_eval(match.groups()[0])
else:
rating = -1

judgement["judgements"].append(judgement_str)
judgement["ratings"].append(rating)

return judgement


def mtbench_single_judge(data: List[Dict], config_path: str):
judgements = []

prompt_dir = os.path.dirname(config_path)
prompts = load_mt_prompts(os.path.join(prompt_dir, "mtbench_judge_prompts.jsonl"))

with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
futures = []
for i, question in enumerate(data):
future = executor.submit(get_mtbench_judgements, question, prompts)
futures.append(future)

for future in tqdm.tqdm(
concurrent.futures.as_completed(futures),
desc=f"MTBench single judge for {data[0]['category']}",
total=len(futures),
):
judgements.append(future.result())

judgements.sort(key=lambda x: x["id"])

judgements_by_id = {j["id"]: j for j in judgements}

data_to_dump = copy.deepcopy(data)

for d in data_to_dump:
id = d["id"]
d["judgements"] = judgements_by_id[id]["judgements"]
d["ratings"] = judgements_by_id[id]["ratings"]

avg_ratings = np.mean([j["ratings"] for j in judgements], axis=0)

return data_to_dump, avg_ratings
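
For context on the score parsing above: the judge prompts ask the model to report its verdict in a bracketed form such as "Rating: [[8]]", which one_score_pattern extracts (the single-bracket pattern is a fallback). A small sketch with a made-up judgement string:

import ast
import re

one_score_pattern = re.compile(r"\[\[(\d+\.?\d*)\]\]")
one_score_pattern_backup = re.compile(r"\[(\d+\.?\d*)\]")

judgement_str = "The response is clear and accurate. Rating: [[8]]"  # made-up judge output

match = re.search(one_score_pattern, judgement_str) or re.search(one_score_pattern_backup, judgement_str)
rating = ast.literal_eval(match.groups()[0]) if match else -1  # -1 when no score can be parsed
print(rating)  # 8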
@@ -185,6 +185,7 @@
"ppl_score_over_choices": ["ALL"],
"ppl_score": ["ALL"],
},
"mtbench": {"mtbench_single_judge": ["ALL"]},
}


10 changes: 8 additions & 2 deletions applications/ColossalEval/colossal_eval/models/huggingface.py
@@ -333,9 +333,12 @@ def inference(self, data: List[Dict], inference_kwargs: Dict[str, Any], debug: b

self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)}

turn = 0 if not isinstance(data[0]["output"], list) else len(data[0]["output"]) + 1
turn_desc = "" if turn == 0 else f"-turn{turn}"

bar = tqdm(
range(math.ceil(len(data) / self.batch_size)),
desc=f"{data[0]['dataset']}-{data[0]['category']} Inference steps",
desc=f"{data[0]['dataset']}-{data[0]['category']}{turn_desc} Inference steps",
disable=not is_rank_0(),
)
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
@@ -384,7 +387,10 @@ def inference(self, data: List[Dict], inference_kwargs: Dict[str, Any], debug: b

for j in range(len(batch_prompt)):
if not pretrain:
answers[i + j]["output"] = batch_decodes[j].strip()
if isinstance(answers[i + j]["output"], list):
answers[i + j]["output"].append(batch_decodes[j].strip())
else:
answers[i + j]["output"] = batch_decodes[j].strip()

if isinstance(scores, torch.Tensor):
answers[i + j]["softmax_over_choices"] = probs[j]
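The turn/turn_desc logic added above distinguishes single-turn data (output is a string) from multi-turn data (output is a list that grows by one answer per pass). A small sketch of how the progress-bar suffix is derived:

def turn_suffix(output):
    # Mirrors the label logic above: 0 means single-turn data; otherwise the
    # 1-indexed turn that is about to be generated.
    turn = 0 if not isinstance(output, list) else len(output) + 1
    return "" if turn == 0 else f"-turn{turn}"

print(turn_suffix(""))                # ""        -> single-turn dataset
print(turn_suffix([]))                # "-turn1"  -> multi-turn, first pass
print(turn_suffix(["first answer"]))  # "-turn2"  -> multi-turn, second pass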
24 changes: 20 additions & 4 deletions applications/ColossalEval/colossal_eval/utils/conversation.py
@@ -171,6 +171,9 @@ def get_batch_prompt(
for b in batch:
few_shot_prefix = ""
if few_shot_data is not None:
assert not isinstance(b["instruction"], list), print(
f"When performing few-shot, {b['dataset']} shouldn't be a multiturn dataset."
)
# For few-shot, only need input. Otherwise use instruction (in AGIEval).
query_text = b["input"] if b.get("input", "") != "" else b["instruction"]

@@ -181,11 +184,24 @@
raise Exception("When using few-shot, target answer should be a string.")

few_shot_prefix = get_few_shot_prefix(conv, few_shot_data, tokenizer, language, max_tokens)
else:
query_text = b["instruction"] + "\n\n" + b["input"] if b.get("input", "") != "" else b["instruction"]

conv.append_message(conv.roles[0], few_shot_prefix + query_text)
conv.append_message(conv.roles[1], None)
conv.append_message(conv.roles[0], few_shot_prefix + query_text)
conv.append_message(conv.roles[1], None)
else:
if not isinstance(b["instruction"], list):
query_text = (
b["instruction"] + "\n\n" + b["input"] if b.get("input", "") != "" else b["instruction"]
)
conv.append_message(conv.roles[0], query_text)
conv.append_message(conv.roles[1], None)
else:
assert len(b["instruction"]) >= len(b["output"]) + 1
cur_turns = len(b["output"])
for turn in range(cur_turns):
conv.append_message(conv.roles[0], b["instruction"][turn])
conv.append_message(conv.roles[1], b["output"][turn])
conv.append_message(conv.roles[0], b["instruction"][cur_turns])
conv.append_message(conv.roles[1], None)

batch_prompt.append(conv.get_prompt())

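To make the multiturn branch above concrete: previously generated turns are replayed as alternating user/assistant messages, then the next instruction is appended with an empty assistant slot for the model to complete. A schematic sketch using plain role/message tuples rather than the actual Conversation class:

from typing import List, Optional, Tuple

def build_turns(instruction: List[str], output: List[str]) -> List[Tuple[str, Optional[str]]]:
    """Replay finished turns, then open the next user turn with no reply yet."""
    assert len(instruction) >= len(output) + 1
    cur_turns = len(output)
    messages: List[Tuple[str, Optional[str]]] = []
    for turn in range(cur_turns):
        messages.append(("user", instruction[turn]))
        messages.append(("assistant", output[turn]))
    messages.append(("user", instruction[cur_turns]))
    messages.append(("assistant", None))  # to be filled by the model
    return messages

# Second pass over a two-turn MT-Bench item: the first answer already exists.
print(build_turns(["Q1", "Q2"], ["A1"]))
# [('user', 'Q1'), ('assistant', 'A1'), ('user', 'Q2'), ('assistant', None)]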