diff --git a/launcher_scripts/nemo_launcher/collections/eval_harness/lm_eval/models/nemo_gpt3.py b/launcher_scripts/nemo_launcher/collections/eval_harness/lm_eval/models/nemo_gpt3.py
index b9c97486e8..09241b7d88 100755
--- a/launcher_scripts/nemo_launcher/collections/eval_harness/lm_eval/models/nemo_gpt3.py
+++ b/launcher_scripts/nemo_launcher/collections/eval_harness/lm_eval/models/nemo_gpt3.py
@@ -24,6 +24,7 @@
 from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
 from nemo.collections.nlp.modules.common.text_generation_utils import generate, get_computeprob_response
 from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector
+from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer
 from nemo.utils import logging
 from nemo.utils.app_state import AppState
 from nemo.utils.get_rank import is_global_rank_zero
@@ -35,11 +36,11 @@
 
 
 class RequestDataset(Dataset):
-    def __init__(self, requests, tokenizer) -> None:
+    def __init__(self, requests, tokenizer, max_length=2048) -> None:
         super().__init__()
         self.requests = requests
         self.tokenizer = tokenizer
-        self.max_length = 2048
+        self.max_length = max_length
 
     def __len__(self):
         return len(self.requests)
@@ -48,6 +49,9 @@ def __getitem__(self, index):
         context, continuation = self.requests[index]
         context_enc = self.tokenizer.text_to_ids(context) if isinstance(context, str) else context
         continuation_enc = self.tokenizer.text_to_ids(continuation) if isinstance(continuation, str) else continuation
+        if isinstance(self.tokenizer, SentencePieceTokenizer):
+            continuation_enc = continuation_enc[1:]
+
         # sanity check
         assert len(context_enc) > 0
         assert len(continuation_enc) > 0
@@ -167,12 +171,8 @@ def __init__(self, args, truncate=False, batch_size=1):
         self.model.eval()
 
         self.max_length = self.model.cfg.get("max_position_embeddings")
-        assert self.tokenizer.text_to_ids("hello\n\nhello") == [
-            31373,
-            198,
-            198,
-            31373,
-        ], "Tokenizer text_to_ids is not working as expected."
+        self.pad_id = self.tokenizer.pad_id
+        self.eos_id = self.tokenizer.eos_id
 
         self.truncate = truncate
         self.batch_size = batch_size
@@ -202,7 +202,8 @@ def loglikelihood(self, requests):
         """
 
     def _loglikelihood(self, requests):
-        def pad_collate(batch, eos_id=50256):
+        def pad_collate(batch):
+            eos_id = self.eos_id
             tokens = [item[0] for item in batch]
             conti_lens = [item[1] for item in batch]
             lens = [len(token) - 1 for token in tokens]  # fake delete last token by reducing input len
@@ -241,7 +242,7 @@ def _collate(x):  # used to reorder request and remove duplications
             return -len(toks), tuple(toks)
 
         reord = utils.Reorderer(requests, _collate)
-        request_ds = RequestDataset(reord.get_reordered(), self.model.tokenizer)
+        request_ds = RequestDataset(reord.get_reordered(), self.model.tokenizer, self.max_length)
         request_dl = DataLoader(request_ds, collate_fn=pad_collate, batch_size=self.batch_size, shuffle=False)
 
         def logits_to_results(batch, response):
@@ -314,7 +315,7 @@ def loglikelihood_rolling(self, requests):
                     utils.make_disjoint_window,
                     utils.get_rolling_token_windows(
                         token_list=self.tokenizer.text_to_ids(string),
-                        prefix_token=50256,
+                        prefix_token=self.eos_id,
                         max_seq_len=self.max_length,
                         context_len=1,
                     ),
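
Note on the `continuation_enc[1:]` change above: a minimal sketch of why the first continuation token is dropped for SentencePiece tokenizers. SentencePiece models trained with the default add_dummy_prefix option prepend a whitespace piece when encoding a string in isolation, so a continuation encoded on its own can start with a spurious leading token that would not appear if it were tokenized jointly with its context. The model path below is a placeholder, and the exact boundary behavior depends on the tokenizer's training options.

import sentencepiece as spm

# Placeholder path; substitute any SentencePiece model file.
sp = spm.SentencePieceProcessor(model_file="tokenizer.model")

context, continuation = "The capital of France is", " Paris"
joint_ids = sp.encode(context + continuation)  # how the pair is seen at eval time
ctx_ids = sp.encode(context)
cont_ids = sp.encode(continuation)

# With add_dummy_prefix=True, cont_ids typically begins with an extra
# whitespace piece; dropping it (cont_ids[1:]) usually realigns the
# standalone encoding with the tail of the joint encoding.
print(joint_ids[len(ctx_ids):])
print(cont_ids[1:])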