diff --git a/docs/chat-template-readme.md b/docs/chat-template-readme.md
new file mode 100644
index 0000000000..d127083d7d
--- /dev/null
+++ b/docs/chat-template-readme.md
@@ -0,0 +1,26 @@
+# Chat Template Delimiter Handling Update
+
+## Overview
+This change modifies how the target delimiter is handled when chat templates are applied during request construction for loglikelihood and multiple-choice tasks. When `apply_chat_template` is set to `True`, the target delimiter is now set to an empty string instead of the configured delimiter.
+
+## Background
+By default, the harness inserts a target delimiter (typically a single space, " ") between the context and the target text when constructing prompts. The full string is constructed as:
+```
+doc_to_text(doc) + target_delimiter + doc_to_target(doc)
+```
+
+While this works well for base models, where we want the model to predict a single space followed by the answer, chat models have their own formatting conventions that handle spacing themselves.
+
+## The Change
+- When `apply_chat_template=True`, the target delimiter is now empty ("") instead of the default whitespace
+- This prevents the default delimiter from interfering with the chat template's own formatting
+- Particularly important for multiple-choice tasks, where the template itself handles spacing
+
+## Example
+```
+# Before (with default delimiter " ")
+Question: What color is the sky?\nAnswer: blue
+
+# After
+Question: What color is the sky?\nAnswer:blue
+```
diff --git a/lm_eval/api/task.py b/lm_eval/api/task.py
index f0be8db99a..56a1ad1601 100644
--- a/lm_eval/api/task.py
+++ b/lm_eval/api/task.py
@@ -449,6 +449,7 @@ def build_all_requests(
                 doc=doc,
                 ctx=fewshot_ctx,
                 metadata=(self.config["task"], doc_id, self.config.repeats),
+                apply_chat_template=apply_chat_template,
             )
 
             if not isinstance(inst, list):
@@ -1301,6 +1302,8 @@ def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]:
     def construct_requests(
         self, doc: dict, ctx: str, **kwargs
     ) -> Union[List[Instance], Instance]:
+        apply_chat_template = kwargs.pop("apply_chat_template", False)
+
         aux_arguments = None
 
         if self.OUTPUT_TYPE == "loglikelihood":
@@ -1310,6 +1313,8 @@ def construct_requests(
         elif self.OUTPUT_TYPE == "multiple_choice":
             choices = self.doc_to_choice(doc)
             target_delimiter = self.config.target_delimiter
+            if apply_chat_template:
+                target_delimiter = ""
             if self.multiple_input:
                 # If there are multiple inputs, choices are placed in the ctx
                 cont = self.doc_to_target(doc)
diff --git a/lm_eval/evaluator.py b/lm_eval/evaluator.py
index d0c1a19a65..cea7d754c7 100644
--- a/lm_eval/evaluator.py
+++ b/lm_eval/evaluator.py
@@ -400,6 +400,11 @@ def evaluate(
 
     eval_logger.setLevel(getattr(logging, f"{verbosity}"))
 
+    if apply_chat_template:
+        eval_logger.warning(
+            "Chat template formatting change affects loglikelihood and multiple-choice tasks. See docs/chat-template-readme.md for details."
+        )
+
     # tracks all Instances/requests a model must generate output on.
     requests = defaultdict(list)
     # stores the amount to pad out reqs per req. type so that
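To make the delimiter change concrete, here is a minimal sketch (not part of the patch) of the `(context, continuation)` pairs the harness scores for a multiple-choice doc. `build_pair` is a hypothetical stand-in for the pairing logic in `Task.construct_requests`; only the delimiter placement reflects the change above.

```python
# Illustrative sketch, not part of the patch: the target delimiter is
# prepended to the continuation being scored, not appended to the context.
def build_pair(context: str, choice: str, target_delimiter: str) -> tuple:
    # The harness scores loglikelihood(continuation | context).
    return context, f"{target_delimiter}{choice}"

context = "Question: What color is the sky?\nAnswer:"

# Base model (default delimiter " "): the model must predict " blue".
print(build_pair(context, "blue", " "))  # (..., ' blue')

# apply_chat_template=True (delimiter ""): the template controls spacing.
print(build_pair(context, "blue", ""))   # (..., 'blue')
```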
diff --git a/lm_eval/models/huggingface.py b/lm_eval/models/huggingface.py
index 2097c2cfea..964dabb245 100644
--- a/lm_eval/models/huggingface.py
+++ b/lm_eval/models/huggingface.py
@@ -4,6 +4,7 @@
 from pathlib import Path
 from typing import Dict, List, Literal, Optional, Tuple, Union
 
+import jinja2
 import torch
 import torch.nn.functional as F
 import transformers
@@ -1344,9 +1345,20 @@ def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
         """
         Method to apply a chat template to a list of chat history between user and model.
         """
-        return self.tokenizer.apply_chat_template(
-            chat_history, tokenize=False, add_generation_prompt=True
-        )
+        try:
+            chat_templated = self.tokenizer.apply_chat_template(
+                chat_history, tokenize=False, add_generation_prompt=True
+            )
+        except jinja2.exceptions.TemplateError:
+            eval_logger.warning(
+                "Failed to apply chat template; removing the system role from chat history."
+            )
+            chat_history = [msg for msg in chat_history if msg["role"] != "system"]
+            chat_templated = self.tokenizer.apply_chat_template(
+                chat_history, tokenize=False, add_generation_prompt=True
+            )
+
+        return chat_templated
 
     def get_model_info(self) -> dict:
         """
diff --git a/lm_eval/tasks/leaderboard/math/_template_yaml b/lm_eval/tasks/leaderboard/math/_template_yaml
index 9c404b0c5e..bc97d202f3 100644
--- a/lm_eval/tasks/leaderboard/math/_template_yaml
+++ b/lm_eval/tasks/leaderboard/math/_template_yaml
@@ -18,7 +18,7 @@ metric_list:
     higher_is_better: true
 num_fewshot: 4
 metadata:
-  version: 1.0
+  version: 2.0
 dataset_kwargs:
   trust_remote_code: true
 fewshot_config:
diff --git a/lm_eval/tasks/leaderboard/math/utils.py b/lm_eval/tasks/leaderboard/math/utils.py
index e3ebcf991b..607be3016c 100644
--- a/lm_eval/tasks/leaderboard/math/utils.py
+++ b/lm_eval/tasks/leaderboard/math/utils.py
@@ -17,6 +17,9 @@
 )
 
 
+INVALID_ANSWER = "[invalidanswer]"
+
+
 # taken from
 # https://github.com/wellecks/lm-evaluation-harness/blob/master/lm_eval/tasks/minerva_math.py
 def doc_to_text(doc: dict) -> str:
@@ -70,7 +73,10 @@ def process_results(doc: dict, results: List[str]) -> Dict[str, int]:
     unnormalized_answer = get_unnormalized_answer(candidates)
     answer = normalize_final_answer(unnormalized_answer)
 
-    if is_equiv(answer, doc["answer"]):
+    if answer == INVALID_ANSWER:
+        return {"exact_match": 0}
+
+    if answer.strip() == doc["answer"].strip() or is_equiv(answer, doc["answer"]):
         retval = 1
     else:
         retval = 0
@@ -112,17 +118,19 @@ def last_boxed_only_string(string: str) -> Optional[str]:
 
 
 def remove_boxed(s: str) -> str:
-    if "\\boxed " in s:
-        left = "\\boxed "
-        assert s[: len(left)] == left
-        return s[len(left) :]
-
-    left = "\\boxed{"
-
-    assert s[: len(left)] == left
-    assert s[-1] == "}"
-
-    return s[len(left) : -1]
+    try:
+        if "\\boxed " in s:
+            left = "\\boxed "
+            assert s[: len(left)] == left
+            return s[len(left) :]
+
+        left = "\\boxed{"
+
+        assert s[: len(left)] == left
+        assert s[-1] == "}"
+        return s[len(left) : -1]
+    except AssertionError:
+        return INVALID_ANSWER
 
 
 class timeout:
@@ -146,7 +154,7 @@ def is_equiv(x1: str, x2: str) -> bool:
     x1 and x2 are normalized latex string
     """
     try:
-        with timeout(seconds=5):
+        with timeout(seconds=1):
             try:
                 parsed_x1 = parse_latex(x1)
                 parsed_x2 = parse_latex(x2)
@@ -185,7 +193,6 @@ def is_equiv(x1: str, x2: str) -> bool:
 
 
 def get_unnormalized_answer(text: str) -> str:
-    INVALID_ANSWER = "[invalidanswer]"
     end_seq = "I hope it is correct."
     text += end_seq
     match = re.search(
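For reference, the fallback added to `huggingface.py` above follows a common retry pattern: some chat templates (for example, several Mistral-style templates that require strictly alternating user/assistant turns) raise `jinja2.exceptions.TemplateError` when handed a `system` message. Below is a minimal, self-contained sketch of that pattern, assuming any Hugging Face tokenizer with a chat template; the function name is illustrative, not part of the harness API.

```python
import jinja2

def apply_chat_template_with_fallback(tokenizer, chat_history):
    """Render chat_history with the tokenizer's chat template, dropping
    system messages and retrying if the template rejects the system role."""
    try:
        return tokenizer.apply_chat_template(
            chat_history, tokenize=False, add_generation_prompt=True
        )
    except jinja2.exceptions.TemplateError:
        # Some templates only accept alternating user/assistant turns,
        # so strip system messages before retrying.
        chat_history = [m for m in chat_history if m["role"] != "system"]
        return tokenizer.apply_chat_template(
            chat_history, tokenize=False, add_generation_prompt=True
        )
```

Note that the fallback silently discards any system-prompt content, which can change task behavior; this is why the harness logs a warning when the retry path is taken.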