From 5fe254dd64d37332347edc73738edcb56096183f Mon Sep 17 00:00:00 2001 From: Scott Yih Date: Thu, 1 Apr 2021 20:03:34 -0700 Subject: [PATCH] truncate the end if the sequence is too long... --- elq/main_dense.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/elq/main_dense.py b/elq/main_dense.py index e7e31048..7641f28f 100644 --- a/elq/main_dense.py +++ b/elq/main_dense.py @@ -229,9 +229,13 @@ def _process_biencoder_dataloader(samples, tokenizer, biencoder_params, logger): max_seq_len = 0 for sample in samples: samples_text_tuple - encoded_sample = [101] + tokenizer.encode(sample['text']) + [102] + # truncate the end if the sequence is too long... + encoded_sample = [101] + tokenizer.encode(sample['text'])[:biencoder_params["max_context_length"]-2] + [102] max_seq_len = max(len(encoded_sample), max_seq_len) samples_text_tuple.append(encoded_sample + [0 for _ in range(biencoder_params["max_context_length"] - len(encoded_sample))]) + + # print(samples_text_tuple) + tensor_data_tuple = [torch.tensor(samples_text_tuple)] tensor_data = TensorDataset(*tensor_data_tuple) sampler = SequentialSampler(tensor_data)