HF: switch conditional checks to self.backend from AUTO_MODEL_CLASS #2353

Merged
merged 12 commits
Oct 8, 2024
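In broad strokes, this PR swaps equality checks against HF auto-classes (`self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM` / `...ForSeq2SeqLM`) for checks against the string attribute `self.backend` ("causal" or "seq2seq"), which `_get_backend` now sets explicitly. A minimal sketch of the before/after pattern, lifted from one of the hunks below (illustrative only; the surrounding HFLM plumbing is elided):

    # before: branch on the HF auto-class
    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
        max_ctx_len = self.max_length - max_gen_toks
    elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
        max_ctx_len = self.max_length

    # after: branch on the backend string set in _get_backend()
    if self.backend == "causal":
        max_ctx_len = self.max_length - max_gen_toks
    elif self.backend == "seq2seq":
        max_ctx_len = self.max_length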
80 changes: 40 additions & 40 deletions lm_eval/models/huggingface.py
@@ -55,7 +55,7 @@ class HFLM(TemplateLM):
def __init__(
self,
pretrained: Union[str, transformers.PreTrainedModel],
backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
backend: Literal["default", "causal", "seq2seq"] = "default",
# override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
revision: Optional[str] = "main",
subfolder: Optional[str] = None,
@@ -90,7 +90,6 @@ def __init__(
**kwargs,
) -> None:
super().__init__()

# optionally: take in an already-initialized transformers.PreTrainedModel
if not isinstance(pretrained, str):
eval_logger.warning(
@@ -164,7 +163,7 @@ def __init__(
trust_remote_code=trust_remote_code,
)

# determine which of 'causal' and 'seq2seq' backends to use
# determine which of 'causal' and 'seq2seq' backends to use for HF models
self._get_backend(
config=self.config, backend=backend, trust_remote_code=trust_remote_code
)
@@ -287,7 +286,7 @@ def __init__(

def _get_accelerate_args(
self,
parallelize: bool = None,
parallelize: Optional[bool] = None,
device_map: Optional[str] = "auto",
max_memory_per_gpu: Optional[Union[int, str]] = None,
max_cpu_memory: Optional[Union[int, str]] = None,
@@ -441,31 +440,26 @@ def tokenizer_name(self) -> str:
def _get_backend(
self,
config: Union[transformers.PretrainedConfig, transformers.AutoConfig],
backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
backend: Literal["default", "causal", "seq2seq"] = "default",
trust_remote_code: Optional[bool] = False,
) -> None:
"""
Helper method during initialization.
Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder))
model type to be used.
Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder)) model type to be used.
sets `self.AUTO_MODEL_CLASS` appropriately if not already set.

**If not calling HFLM.__init__() or HFLM._get_backend() within a subclass of HFLM,
user must set `self.backend` to be either "causal" or "seq2seq" manually!**
"""
# escape hatch: if we're using a subclass that shouldn't follow
# the default _get_backend logic,
# then skip over the method.
# TODO: this seems very much undesirable in some cases--our code in HFLM
# references AutoModelForCausalLM at times to check for equality
if self.AUTO_MODEL_CLASS is not None:
return

assert backend in ["default", "causal", "seq2seq"]

if backend != "default":
# if we've settled on non-default backend, use that manually
if backend == "causal":
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
self.backend = backend
elif backend == "seq2seq":
self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
self.backend = backend
eval_logger.info(
f"Overrode HF model backend type, and using type '{backend}'"
)
@@ -478,33 +472,40 @@ def _get_backend(
# first check if model type is listed under seq2seq models, since some
# models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
# these special cases should be treated as seq2seq models.
self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
self.backend = "seq2seq"
eval_logger.info(f"Using model type '{backend}'")
elif (
getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
):
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
self.backend = "causal"
eval_logger.info(f"Using model type '{backend}'")
else:
if not trust_remote_code:
eval_logger.warning(
"HF model type is neither marked as CausalLM or Seq2SeqLM. \
This is expected if your model requires `trust_remote_code=True` but may be an error otherwise."
"Setting backend to causal"
)
# if model type is neither in HF transformers causal or seq2seq model registries
# then we default to AutoModelForCausalLM
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
# then we default to assuming AutoModelForCausalLM
self.backend = "causal"
eval_logger.info(
f"Model type cannot be determined. Using default model type '{backend}'"
)

assert self.AUTO_MODEL_CLASS in [
transformers.AutoModelForCausalLM,
transformers.AutoModelForSeq2SeqLM,
]
return None
if self.AUTO_MODEL_CLASS is None:
if self.backend == "causal":
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
elif self.backend == "seq2seq":
self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM

def _get_config(
self,
pretrained: str,
revision: str = "main",
trust_remote_code: bool = False,
) -> None:
"""Return the model config for HuggingFace models"""
self._config = transformers.AutoConfig.from_pretrained(
pretrained,
revision=revision,
@@ -703,7 +704,7 @@ def _detect_batch_size(self, requests=None, pos: int = 0):
# if OOM, then halves batch_size and tries again
@find_executable_batch_size(starting_batch_size=self.max_batch_size)
def forward_batch(batch_size):
if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
if self.backend == "seq2seq":
length = max(max_context_enc, max_cont_enc)
batched_conts = torch.ones(
(batch_size, length), device=self.device
@@ -754,7 +755,7 @@ def tok_encode(

# by default for CausalLM - false or self.add_bos_token is set
if add_special_tokens is None:
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
if self.backend == "causal":
special_tokens_kwargs = {
"add_special_tokens": False or self.add_bos_token
}
@@ -782,7 +783,7 @@ def tok_batch_encode(
self.tokenizer.padding_side = padding_side

add_special_tokens = {}
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
if self.backend == "causal":
add_special_tokens = {"add_special_tokens": False or self.add_bos_token}

encoding = self.tokenizer(
@@ -860,14 +861,14 @@ def _model_generate(self, context, max_length, stop, **generation_kwargs):
def _select_cont_toks(
self, logits: torch.Tensor, contlen: int = None, inplen: int = None
) -> torch.Tensor:
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
if self.backend == "causal":
assert (
contlen and inplen
), "Must pass input len and cont. len to select scored logits for causal LM"
# discard right-padding.
# also discard the input/context tokens. we'll only score continuations.
logits = logits[inplen - contlen : inplen]
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
elif self.backend == "seq2seq":
assert (
contlen and not inplen
), "Selecting scored logits for Seq2SeqLM requires only cont. len"
@@ -990,8 +991,7 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
requests,
sort_fn=_collate,
group_by="contexts"
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
and self.logits_cache
if self.backend == "causal" and self.logits_cache
else None,
group_fn=_lookup_one_token_cont,
)
@@ -1048,14 +1048,14 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
# cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice

# when too long to fit in context, truncate from the left
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
if self.backend == "causal":
inp = torch.tensor(
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
dtype=torch.long,
device=self.device,
)
(inplen,) = inp.shape
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
elif self.backend == "seq2seq":
inp = torch.tensor(
(context_enc)[-self.max_length :],
dtype=torch.long,
@@ -1095,11 +1095,11 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):

# create encoder attn mask and batched conts, if seq2seq
call_kwargs = {}
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
if self.backend == "causal":
batched_inps = pad_and_concat(
padding_len_inp, inps, padding_side="right"
) # [batch, padding_len_inp]
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
elif self.backend == "seq2seq":
# TODO: left-pad encoder inps and mask?
batched_inps = pad_and_concat(
padding_len_inp, inps
@@ -1130,7 +1130,7 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
# from prompt/prefix tuning tokens, if applicable
ctx_len = (
inplen + (logits.shape[0] - padding_len_inp)
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
if self.backend == "causal"
else None
)
logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
@@ -1265,10 +1265,10 @@ def _collate(req: Tuple[str, dict]):
max_gen_toks = self.max_gen_toks

# set the max length in tokens of inputs ("context_enc")
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
if self.backend == "causal":
# max len for inputs = max length, minus room to generate the max new tokens
max_ctx_len = self.max_length - max_gen_toks
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
elif self.backend == "seq2seq":
# max len for inputs = encoder's whole max_length
max_ctx_len = self.max_length

@@ -1295,7 +1295,7 @@ def _collate(req: Tuple[str, dict]):
cont_toks_list = cont.tolist()
for cont_toks, context in zip(cont_toks_list, contexts):
# discard context + left-padding toks if using causal decoder-only LM
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
if self.backend == "causal":
cont_toks = cont_toks[context_enc.shape[1] :]

s = self.tok_decode(cont_toks)
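As the updated `_get_backend` docstring notes, a subclass that does not go through `HFLM.__init__()` or `HFLM._get_backend()` must now set `self.backend` to "causal" or "seq2seq" itself. A hypothetical minimal sketch of such a subclass (the class name and the hard-coded "causal" choice are made-up illustrations, not part of this PR):

    from lm_eval.models.huggingface import HFLM

    class CustomHFLM(HFLM):  # hypothetical subclass, for illustration only
        def __init__(self, *args, **kwargs):
            # this subclass does its own model setup and never calls
            # HFLM.__init__()/_get_backend(), so it declares the backend directly
            self.backend = "causal"  # use "seq2seq" for encoder-decoder models
            # ... custom model/tokenizer initialization would go here ...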