Fix eval harness on next nvgpt (#117)
* Fix evaluation harness

* Fix a few issues
yaoyu-33 authored Sep 8, 2023
1 parent a82887b commit 52d3c08
Showing 1 changed file with 12 additions and 11 deletions.
@@ -24,6 +24,7 @@
 from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
 from nemo.collections.nlp.modules.common.text_generation_utils import generate, get_computeprob_response
 from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector
+from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer
 from nemo.utils import logging
 from nemo.utils.app_state import AppState
 from nemo.utils.get_rank import is_global_rank_zero
@@ -35,11 +36,11 @@
 
 
 class RequestDataset(Dataset):
-    def __init__(self, requests, tokenizer) -> None:
+    def __init__(self, requests, tokenizer, max_length=2048) -> None:
         super().__init__()
         self.requests = requests
         self.tokenizer = tokenizer
-        self.max_length = 2048
+        self.max_length = max_length
 
     def __len__(self):
         return len(self.requests)
@@ -48,6 +49,9 @@ def __getitem__(self, index):
         context, continuation = self.requests[index]
         context_enc = self.tokenizer.text_to_ids(context) if isinstance(context, str) else context
         continuation_enc = self.tokenizer.text_to_ids(continuation) if isinstance(continuation, str) else continuation
+        if isinstance(self.tokenizer, SentencePieceTokenizer):
+            continuation_enc = continuation_enc[1:]
+
         # sanity check
         assert len(context_enc) > 0
         assert len(continuation_enc) > 0
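A note on the new SentencePieceTokenizer branch: slicing off the first continuation id appears to compensate for SentencePiece-style tokenizers emitting an extra leading piece (the whitespace marker) when a continuation such as " Paris" is encoded on its own, an id that would not be present if the full context-plus-continuation string were tokenized in one pass. Below is a minimal sketch with a mock tokenizer (hypothetical ids, not NeMo's SentencePieceTokenizer) of the mismatch the slice removes:

    class MockSentencePiece:
        # Hypothetical stand-in: every standalone encoding starts with a spurious
        # whitespace-marker id, mimicking SentencePiece's leading "▁" piece.
        WHITESPACE_ID = 29871

        def text_to_ids(self, text):
            return [self.WHITESPACE_ID] + [abs(hash(w)) % 1000 for w in text.split()]

    tok = MockSentencePiece()
    context_enc = tok.text_to_ids("The capital of France is")
    continuation_enc = tok.text_to_ids(" Paris")

    # Concatenating the raw encodings would leave a bogus whitespace id at the
    # context/continuation boundary; dropping the first id restores alignment.
    assert continuation_enc[0] == MockSentencePiece.WHITESPACE_ID
    continuation_enc = continuation_enc[1:]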
@@ -167,12 +171,8 @@ def __init__(self, args, truncate=False, batch_size=1):
         self.model.eval()
 
         self.max_length = self.model.cfg.get("max_position_embeddings")
-        assert self.tokenizer.text_to_ids("hello\n\nhello") == [
-            31373,
-            198,
-            198,
-            31373,
-        ], "Tokenizer text_to_ids is not working as expected."
+        self.pad_id = self.tokenizer.pad_id
+        self.eos_id = self.tokenizer.eos_id
 
         self.truncate = truncate
         self.batch_size = batch_size
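The deleted assertion only held for GPT-2 BPE checkpoints: 31373 and 198 are GPT-2's ids for "hello" and "\n", so any SentencePiece-tokenized model would fail it on load. The replacement takes the pad and EOS ids from whatever tokenizer the checkpoint actually ships, and those ids are reused further down in place of the hard-coded 50256. A rough sketch of that pattern (not NeMo's API; the pad fallback here is an assumption, not something this commit does):

    def resolve_special_ids(tokenizer):
        """Read special-token ids from the model's own tokenizer instead of
        hard-coding GPT-2 BPE ids such as 50256."""
        pad_id = getattr(tokenizer, "pad_id", None)
        eos_id = getattr(tokenizer, "eos_id", None)
        # Hypothetical fallback: reuse EOS for padding when no pad token exists.
        if pad_id is None or pad_id < 0:
            pad_id = eos_id
        return pad_id, eos_id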
@@ -202,7 +202,8 @@ def loglikelihood(self, requests):
         """
 
     def _loglikelihood(self, requests):
-        def pad_collate(batch, eos_id=50256):
+        def pad_collate(batch):
+            eos_id = self.eos_id
             tokens = [item[0] for item in batch]
             conti_lens = [item[1] for item in batch]
             lens = [len(token) - 1 for token in tokens]  # fake delete last token by reducing input len
@@ -241,7 +242,7 @@ def _collate(x):  # used to reorder request and remove duplications
             return -len(toks), tuple(toks)
 
         reord = utils.Reorderer(requests, _collate)
-        request_ds = RequestDataset(reord.get_reordered(), self.model.tokenizer)
+        request_ds = RequestDataset(reord.get_reordered(), self.model.tokenizer, self.max_length)
         request_dl = DataLoader(request_ds, collate_fn=pad_collate, batch_size=self.batch_size, shuffle=False)
 
         def logits_to_results(batch, response):
@@ -314,7 +315,7 @@ def loglikelihood_rolling(self, requests):
                     utils.make_disjoint_window,
                     utils.get_rolling_token_windows(
                         token_list=self.tokenizer.text_to_ids(string),
-                        prefix_token=50256,
+                        prefix_token=self.eos_id,
                         max_seq_len=self.max_length,
                         context_len=1,
                     ),
