Commit 9dc12c4

fixed perplexity
VeritasJoker committed Feb 15, 2024
1 parent 2665ca0 commit 9dc12c4
Showing 1 changed file with 33 additions and 30 deletions.
63 changes: 33 additions & 30 deletions scripts/tfsemb_perplexity.py
@@ -27,13 +27,13 @@ def main():
     args = arg_parser()
     setup_environ(args)
 
-    if os.path.exists(args.base_df_file):
-        base_df = load_pickle(args.base_df_file)
-    else:
-        raise Exception("Base dataframe does not exist")
+    # if os.path.exists(args.base_df_file):
+    #     base_df = load_pickle(args.base_df_file)
+    # else:
+    #     raise Exception("Base dataframe does not exist")
 
-    # base_df_path = args.base_df_file.replace("661/embeddings", "777/pickles/embeddings")
-    # base_df = load_pickle(base_df_path)
+    base_df_path = args.base_df_file.replace("661/embeddings", "777/pickles/embeddings")
+    base_df = load_pickle(base_df_path)
 
     utterance_df = select_conversation(args, base_df)
     assert len(utterance_df) != 0, "Empty dataframe"
@@ -42,40 +42,43 @@ def main():
         max_length = args.model.config.n_positions
     except:
         max_length = args.model.config.max_position_embeddings
-    stride = 2048
+    strides = [512, 1024, 2048, 4096]
 
     encodings = torch.tensor([tuple(utterance_df.token_id.tolist())])
     seq_len = encodings.size(1)
 
-    nlls = []
-    prev_end_loc = 0
-    model = args.model
-    device = args.device
-
-    for begin_loc in tqdm(range(0, seq_len, stride)):
-        end_loc = min(begin_loc + max_length, seq_len)
-        trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
-        input_ids = encodings[:, begin_loc:end_loc].to(device)
-        target_ids = input_ids.clone().to(device)
-        target_ids[:, :-trg_len] = -100
+    for stride in strides:
+        nlls = []
+        prev_end_loc = 0
+        model = args.model
+        device = args.device
+        for begin_loc in tqdm(range(0, seq_len, stride)):
+            end_loc = min(begin_loc + max_length, seq_len)
+            trg_len = (
+                end_loc - prev_end_loc
+            )  # may be different from stride on last loop
+            input_ids = encodings[:, begin_loc:end_loc].to(device)
+            target_ids = input_ids.clone().to(device)
+            target_ids[:, :-trg_len] = -100
 
-        with torch.no_grad():
-            model = model.to(device)
-            model.eval()
-            outputs = args.model(input_ids, labels=target_ids)
+            with torch.no_grad():
+                model = model.to(device)
+                model.eval()
+                outputs = model(input_ids, labels=target_ids)
 
-            neg_log_likelihood = outputs.loss
+                neg_log_likelihood = outputs.loss
 
-        nlls.append(neg_log_likelihood)
+            nlls.append(neg_log_likelihood)
 
-        prev_end_loc = end_loc
-        if end_loc == seq_len:
-            break
+            prev_end_loc = end_loc
+            if end_loc == seq_len:
+                break
 
-    ppl = torch.exp(torch.stack(nlls).mean())
-    print(f"Emb: {args.embedding_type}, Length: {max_length}, Perplexity: {ppl}")
+        ppl = torch.exp(torch.stack(nlls).mean())
+        print(f"Emb: {args.embedding_type}, Stride: {stride}, Perplexity: {ppl}")
 
     return
 
 
 if __name__ == "__main__":
-    main()
+    main()
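For reference, the loop changed in this commit follows the standard sliding-window perplexity recipe from the Hugging Face documentation: each window covers up to max_length tokens, and every label except the final trg_len positions is masked with -100 so each token is scored exactly once per stride setting. Below is a minimal, self-contained sketch of the same computation; the "gpt2" checkpoint, placeholder text, and stride values are stand-ins for illustration, not values from this repository (which uses args.model and the utterance dataframe's token_id column).

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Placeholder text stands in for the utterance dataframe's token_id column.
text = "..."  # hypothetical evaluation text
encodings = tokenizer(text, return_tensors="pt").input_ids
seq_len = encodings.size(1)

# GPT-2 configs expose n_positions; others use max_position_embeddings.
max_length = getattr(model.config, "n_positions", None) or model.config.max_position_embeddings

for stride in [256, 512, 1024]:  # strides no larger than the context window
    nlls = []
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # shorter than stride on the last window
        input_ids = encodings[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        target_ids[:, :-trg_len] = -100  # score only the tokens new to this window

        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
        nlls.append(outputs.loss)

        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    ppl = torch.exp(torch.stack(nlls).mean())
    print(f"stride={stride}, perplexity={ppl.item():.2f}")

Smaller strides give each scored token more left context at the cost of more forward passes, so perplexity typically decreases as the stride shrinks; sweeping several strides, as this commit does, makes that trade-off visible.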
