Skip to content

Commit

Permalink
use HFMultimodalLM.chat_template; restore tasks/__init__.py
Browse files Browse the repository at this point in the history
  • Loading branch information
haileyschoelkopf committed Sep 13, 2024
1 parent 4623768 commit 9ddb2ec
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 17 deletions.
4 changes: 1 addition & 3 deletions lm_eval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,7 @@ def _adjust_config(task_dict):
model_source=model,
model_args=model_args,
system_instruction=system_instruction,
# TODO: change this back
# chat_template=lm.chat_template(apply_chat_template),
chat_template=None,
chat_template=lm.chat_template(apply_chat_template),
fewshot_as_multiturn=fewshot_as_multiturn,
)

Expand Down
7 changes: 1 addition & 6 deletions lm_eval/models/vllm_vlms.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
def generate_until(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[str]:
# TODO: back out to HFLM.generate_until() for all requests without aux_arguments (text-only reqs)
# TODO: support text-only reqs
res = []

def _collate(x):
Expand Down Expand Up @@ -214,8 +214,6 @@ def _collate(x):
)
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)

### Up to here: was identical to non-multimodal HFLM generate_until ###

for chunk in chunks:
contexts, all_gen_kwargs, aux_arguments = zip(*chunk)

Expand All @@ -226,7 +224,6 @@ def _collate(x):
contexts
) # for Qwen2-VL, processor is unhappy accepting a tuple of strings instead of a list.
# TODO: could we upstream this workaround to HF?
### this part onward: same as HFLM ###

# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
Expand Down Expand Up @@ -258,8 +255,6 @@ def _collate(x):
else:
max_gen_toks = self.max_gen_toks

### end stuff that's entirely copied verbatim from HFLM ###

max_ctx_len = self.max_length - max_gen_toks

inputs = self.tok_batch_multimodal_encode(
Expand Down
14 changes: 6 additions & 8 deletions lm_eval/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,12 @@ def _load_individual_task_or_group(
name_or_config: Optional[Union[str, dict]] = None,
parent_name: Optional[str] = None,
update_config: Optional[dict] = None,
yaml_path: Optional[str] = None,
) -> Mapping:
def _load_task(config, task, yaml_path=None):
def _load_task(config, task):
if "include" in config:
config = {
**utils.load_yaml_config(
yaml_path=yaml_path,
yaml_path=None,
yaml_config={"include": config.pop("include")},
mode="full",
),
Expand All @@ -275,6 +274,7 @@ def _load_task(config, task, yaml_path=None):
task_object.config.task = config["task"]
else:
task_object = ConfigurableTask(config=config)

return {task: task_object}

def _get_group_and_subtask_from_config(config):
Expand Down Expand Up @@ -316,7 +316,6 @@ def _process_group_config(config, update_config=None):
group_name, subtask_list = _get_group_and_subtask_from_config(
group_config
)
yaml_path = self._get_yaml_path(group_name.group)
else:
if self._name_is_tag(name_or_config):
fn = partial(
Expand All @@ -335,8 +334,7 @@ def _process_group_config(config, update_config=None):

if isinstance(name_or_config, dict):
if self._config_is_task(name_or_config):
# name = name_or_config.pop("task")
name = name_or_config["task"]
name = name_or_config.pop("task")
if update_config is not None:
name_or_config = {**name_or_config, **update_config}
# If the name is registered as a group
Expand Down Expand Up @@ -380,7 +378,7 @@ def _process_group_config(config, update_config=None):
}
else:
task_config = name_or_config
return _load_task(task_config, task=name, yaml_path=yaml_path)
return _load_task(task_config, task=name)
else:
group_config, update_config = _process_group_config(name_or_config)
group_name, subtask_list = _get_group_and_subtask_from_config(
Expand All @@ -391,7 +389,6 @@ def _process_group_config(config, update_config=None):
self._load_individual_task_or_group,
parent_name=group_name,
update_config=update_config,
yaml_path=yaml_path,
)
return {
group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
Expand Down Expand Up @@ -651,4 +648,5 @@ def get_task_dict(
# and we'd be unsure which to use and report.)
# we explicitly check and error in this case.
_check_duplicates(get_subtask_list(final_task_dict))

return final_task_dict

0 comments on commit 9ddb2ec

Please sign in to comment.