From 61211f3bacc34626d7c07b0dbd375645f57a7a34 Mon Sep 17 00:00:00 2001 From: Phil Yeres Date: Tue, 25 Feb 2020 11:47:15 -0500 Subject: [PATCH] Address max sequence length issue (#1002) (#1006) * add wrapper to apply truncation strategy * update interface to require keyword arg for s2 (w/ PEP 3102) * update tests to call boundary_token_fn w/ new req kwarg * update calls to boundary_token_fn to include req. s2 kwarg * update apply_standard_boundary_tokens to require keyword arg for s2 (w/ PEP 3102) * get max model input sizes from models, config to truncate to min of max model input size and max_seq_len * update trunc strategy to allow left- or right-trunc * add model processing interface trunc tests * raise exception if truncation is needed but strategy is unspecified. * expect ValueError with unspecified trunc strat in test * update _apply_boundary_tokens_with_trunc_strategy docstring * simplify/reduce logging during truncation * add option to trunc both s1 and s2 evenly * add test for trunc both s1 and s2 evenly * add reasonable truncation args for most tasks --- .../modules.py | 18 +- jiant/preprocess.py | 155 +++++++++++++++++- jiant/tasks/edge_probing.py | 2 +- jiant/tasks/qa.py | 58 +++++-- jiant/tasks/tasks.py | 119 ++++++++++---- jiant/utils/utils.py | 2 +- ...test_huggingface_transformers_interface.py | 12 +- tests/test_preprocess.py | 123 +++++++++++++- 8 files changed, 428 insertions(+), 61 deletions(-) diff --git a/jiant/huggingface_transformers_interface/modules.py b/jiant/huggingface_transformers_interface/modules.py index d8957e0cd..8b10998e7 100644 --- a/jiant/huggingface_transformers_interface/modules.py +++ b/jiant/huggingface_transformers_interface/modules.py @@ -189,7 +189,7 @@ def get_seg_ids(self, token_ids, input_mask): return seg_ids @staticmethod - def apply_boundary_tokens(s1, s2=None, get_offset=False): + def apply_boundary_tokens(s1, *, s2=None, get_offset=False): """ A function that appliese the appropriate EOS/SOS/SEP/CLS tokens to 
token sequence or token sequence pair for most tasks. @@ -274,7 +274,7 @@ def __init__(self, args): self.parameter_setup(args) @staticmethod - def apply_boundary_tokens(s1, s2=None, get_offset=False): + def apply_boundary_tokens(s1, *, s2=None, get_offset=False): # BERT-style boundary token padding on string token sequences if s2: s = ["[CLS]"] + s1 + ["[SEP]"] + s2 + ["[SEP]"] @@ -331,7 +331,7 @@ def __init__(self, args): self.parameter_setup(args) @staticmethod - def apply_boundary_tokens(s1, s2=None, get_offset=False): + def apply_boundary_tokens(s1, *, s2=None, get_offset=False): # RoBERTa-style boundary token padding on string token sequences if s2: s = [""] + s1 + ["", ""] + s2 + [""] @@ -385,7 +385,7 @@ def __init__(self, args): self.parameter_setup(args) @staticmethod - def apply_boundary_tokens(s1, s2=None, get_offset=False): + def apply_boundary_tokens(s1, *, s2=None, get_offset=False): # ALBERT-style boundary token padding on string token sequences if s2: s = ["[CLS]"] + s1 + ["[SEP]"] + s2 + ["[SEP]"] @@ -448,7 +448,7 @@ def __init__(self, args): self._SEG_ID_SEP = 3 @staticmethod - def apply_boundary_tokens(s1, s2=None, get_offset=False): + def apply_boundary_tokens(s1, *, s2=None, get_offset=False): # XLNet-style boundary token marking on string token sequences if s2: s = s1 + [""] + s2 + ["", ""] @@ -506,7 +506,7 @@ def __init__(self, args): self.parameter_setup(args) @staticmethod - def apply_boundary_tokens(s1, s2=None, get_offset=False): + def apply_boundary_tokens(s1, *, s2=None, get_offset=False): # OpenAI-GPT-style boundary token marking on string token sequences if s2: s = [""] + s1 + [""] + s2 + [""] @@ -568,7 +568,7 @@ def __init__(self, args): self.parameter_setup(args) @staticmethod - def apply_boundary_tokens(s1, s2=None, get_offset=False): + def apply_boundary_tokens(s1, *, s2=None, get_offset=False): # GPT-2-style boundary token marking on string token sequences if s2: s = [""] + s1 + [""] + s2 + [""] @@ -630,7 +630,7 @@ def 
__init__(self, args): self.parameter_setup(args) @staticmethod - def apply_boundary_tokens(s1, s2=None, get_offset=False): + def apply_boundary_tokens(s1, *, s2=None, get_offset=False): # TransformerXL-style boundary token marking on string token sequences if s2: s = [""] + s1 + [""] + s2 + [""] @@ -697,7 +697,7 @@ def __init__(self, args): self.parameter_setup(args) @staticmethod - def apply_boundary_tokens(s1, s2=None, get_offset=False): + def apply_boundary_tokens(s1, *, s2=None, get_offset=False): # XLM-style boundary token marking on string token sequences if s2: s = [""] + s1 + [""] + s2 + [""] diff --git a/jiant/preprocess.py b/jiant/preprocess.py index 27e3b6e18..815c383d6 100644 --- a/jiant/preprocess.py +++ b/jiant/preprocess.py @@ -740,53 +740,202 @@ class ModelPreprocessingInterface(object): """ + @staticmethod + def _apply_boundary_tokens(apply_boundary_tokens, max_tokens): + """ + Takes a function that applies boundary tokens and modifies it to respect a max_seq_len arg. + + Parameters + ---------- + apply_boundary_tokens : function + Takes a function that adds boundary tokens and implements the common boundary token + adding function interface demonstated in HuggingfaceTransformersEmbedderModule. + max_tokens : int + The maximum number of tokens to allow in the output sequence. + + Returns + ------- + apply_boundary_tokens_with_trunc_strategy : function + Function w/ the common boundary token adding interface demonstrated in + HuggingfaceTransformersEmbedderModule, but also implementing a truncation strategy. + """ + + def _apply_boundary_tokens_with_trunc_strategy( + *args, trunc_strategy=None, trunc_side=None, **kwargs + ): + """ + Calls the apply_boundary_tokens function provided in the parent function and, if the + output exceeds the max number of tokens, applies a truncation strategy to the inputs, + then re-applies the apply_boundary_tokens function on the truncated inputs and returns + the result. 
+ + Parameters + ---------- + args : tuple + see args docs for apply_boundary_tokens and apply_lm_boundary_tokens in + huggingface_transformers_interface.modules + trunc_strategy : str + Which strings to truncate. Options: + `trunc_s1`: select s1 for truncation. + `trunc_s2`: select s2 for truncation. + `trunc_both`: truncate s1 and s2 equally. + trunc_side : str + Options are `right` or `left`. Indicates which side of the sequence to remove from. + e.g., if `left`, then the first element to be removed from ["A", "B", "C"] is "A". + kwargs : dict + see args docs for apply_boundary_tokens and apply_lm_boundary_tokens in + huggingface_transformers_interface.modules + + Returns + ------- + seq_w_boundry_tokens : List[str] + List of tokens returned by applying the specified apply_boundary_tokens function + to the inputs and using the specified truncation strategy. + + Raises + ------ + ValueError + If sequence requires truncation but trunc_strategy or trunc_side are unspecified + or invalid. + + """ + seq_w_boundry_tokens = apply_boundary_tokens(*args, **kwargs) + # if after calling the apply_boundary_tokens function the number of tokens is greater + # than the model's max, a truncation strategy can reduce the number of tokens: + num_excess_tokens = len(seq_w_boundry_tokens) - max_tokens + if num_excess_tokens > 0: + if not (trunc_strategy and trunc_side): + raise ValueError( + "Input(s) length will exceed model capacity or max_seq_len. " + + "Adjust settings as necessary, or, to automatically truncate " + + "inputs in `apply_boundary_tokens`, call `apply_boundary_tokens` " + + "with `trunc_strategy` and `trunc_side` keyword args." 
+ ) + if trunc_strategy == "trunc_s2": + s2 = kwargs["s2"] + if trunc_side == "right": + s2_truncated = s2[:-num_excess_tokens] + elif trunc_side == "left": + s2_truncated = s2[num_excess_tokens:] + log.info( + "Before truncation s2 length = " + + str(len(s2)) + + ", after truncation s2 length = " + + str(len(s2_truncated)) + ) + assert len(s2_truncated) > 0, "After truncation, s2 length would be 0." + args = list(args) + kwargs["s2"] = s2_truncated + return apply_boundary_tokens(*args, **kwargs) + elif trunc_strategy == "trunc_both": + s1 = args[0] + s2 = kwargs["s2"] + if trunc_side == "right": + s1_truncated = s1[: -num_excess_tokens // 2] + s2_truncated = s2[: -num_excess_tokens // 2] + elif trunc_side == "left": + s1_truncated = s1[num_excess_tokens // 2 :] + s2_truncated = s2[num_excess_tokens // 2 :] + log.info( + "Before truncation s1 length = " + + str(len(s1)) + + " and s2 length = " + + str(len(s2)) + + ". After truncation s1 length = " + + str(len(s1_truncated)) + + " and s2 length = " + + str(len(s2_truncated)) + ) + assert len(s1_truncated) > 0, "After truncation, s1 length would be 0." + assert len(s2_truncated) > 0, "After truncation, s2 length would be 0." + args = list(args) + args[0] = s1_truncated + kwargs["s2"] = s2_truncated + return apply_boundary_tokens(*args, **kwargs) + elif trunc_strategy == "trunc_s1": + s1 = args[0] + if trunc_side == "right": + s1_truncated = s1[:-num_excess_tokens] + elif trunc_side == "left": + s1_truncated = s1[num_excess_tokens:] + log.info( + "Before truncation s1 length = " + + str(len(s1)) + + ", after truncation s1 length = " + + str(len(s1_truncated)) + ) + assert len(s1_truncated) > 0, "After truncation, s1 length would be 0." 
+ args = list(args) + args[0] = s1_truncated + return apply_boundary_tokens(*args, **kwargs) + else: + raise ValueError(trunc_strategy + " is not a valid truncation strategy.") + else: + return seq_w_boundry_tokens + + return _apply_boundary_tokens_with_trunc_strategy + def __init__(self, args): boundary_token_fn = None lm_boundary_token_fn = None + max_pos: int = None if args.input_module.startswith("bert-"): from jiant.huggingface_transformers_interface.modules import BertEmbedderModule boundary_token_fn = BertEmbedderModule.apply_boundary_tokens + max_pos = BertTokenizer.max_model_input_sizes.get(args.input_module, None) elif args.input_module.startswith("roberta-"): from jiant.huggingface_transformers_interface.modules import RobertaEmbedderModule boundary_token_fn = RobertaEmbedderModule.apply_boundary_tokens + max_pos = RobertaTokenizer.max_model_input_sizes.get(args.input_module, None) elif args.input_module.startswith("albert-"): from jiant.huggingface_transformers_interface.modules import AlbertEmbedderModule boundary_token_fn = AlbertEmbedderModule.apply_boundary_tokens + max_pos = AlbertTokenizer.max_model_input_sizes.get(args.input_module, None) elif args.input_module.startswith("xlnet-"): from jiant.huggingface_transformers_interface.modules import XLNetEmbedderModule boundary_token_fn = XLNetEmbedderModule.apply_boundary_tokens + max_pos = XLNetTokenizer.max_model_input_sizes.get(args.input_module, None) elif args.input_module.startswith("openai-gpt"): from jiant.huggingface_transformers_interface.modules import OpenAIGPTEmbedderModule boundary_token_fn = OpenAIGPTEmbedderModule.apply_boundary_tokens lm_boundary_token_fn = OpenAIGPTEmbedderModule.apply_lm_boundary_tokens + max_pos = OpenAIGPTTokenizer.max_model_input_sizes.get(args.input_module, None) elif args.input_module.startswith("gpt2"): from jiant.huggingface_transformers_interface.modules import GPT2EmbedderModule boundary_token_fn = GPT2EmbedderModule.apply_boundary_tokens 
lm_boundary_token_fn = GPT2EmbedderModule.apply_lm_boundary_tokens + max_pos = GPT2Tokenizer.max_model_input_sizes.get(args.input_module, None) elif args.input_module.startswith("transfo-xl-"): from jiant.huggingface_transformers_interface.modules import TransfoXLEmbedderModule boundary_token_fn = TransfoXLEmbedderModule.apply_boundary_tokens lm_boundary_token_fn = TransfoXLEmbedderModule.apply_lm_boundary_tokens + max_pos = TransfoXLTokenizer.max_model_input_sizes.get(args.input_module, None) elif args.input_module.startswith("xlm-"): from jiant.huggingface_transformers_interface.modules import XLMEmbedderModule boundary_token_fn = XLMEmbedderModule.apply_boundary_tokens + max_pos = XLMTokenizer.max_model_input_sizes.get(args.input_module, None) else: boundary_token_fn = utils.apply_standard_boundary_tokens - self.boundary_token_fn = boundary_token_fn + self.max_tokens = min(x for x in [max_pos, args.max_seq_len] if x is not None) + + self.boundary_token_fn = self._apply_boundary_tokens(boundary_token_fn, self.max_tokens) + if lm_boundary_token_fn is not None: - self.lm_boundary_token_fn = lm_boundary_token_fn + self.lm_boundary_token_fn = self._apply_boundary_tokens( + lm_boundary_token_fn, args.max_seq_len + ) else: - self.lm_boundary_token_fn = boundary_token_fn + self.lm_boundary_token_fn = self.boundary_token_fn from jiant.models import input_module_uses_pair_embedding, input_module_uses_mirrored_pair diff --git a/jiant/tasks/edge_probing.py b/jiant/tasks/edge_probing.py index 6821441ef..b6607382b 100644 --- a/jiant/tasks/edge_probing.py +++ b/jiant/tasks/edge_probing.py @@ -182,7 +182,7 @@ def make_instance(self, record, idx, indexers, model_preprocessing_interface) -> """Convert a single record to an AllenNLP Instance.""" tokens = record["text"].split() # already space-tokenized by Moses tokens = model_preprocessing_interface.boundary_token_fn( - tokens + tokens, trunc_strategy="trunc_s1", trunc_side="right" ) # apply model-appropriate variants of [cls] and 
[sep]. text_field = sentence_to_text_field(tokens, indexers) diff --git a/jiant/tasks/qa.py b/jiant/tasks/qa.py index 24da6d68f..7b2492908 100644 --- a/jiant/tasks/qa.py +++ b/jiant/tasks/qa.py @@ -120,11 +120,16 @@ def _make_instance(passage, question, answer, label, par_idx, qst_idx, ans_idx): d["ans_idx"] = MetadataField(ans_idx) d["idx"] = MetadataField(ans_idx) # required by evaluate() if model_preprocessing_interface.model_flags["uses_pair_embedding"]: - inp = model_preprocessing_interface.boundary_token_fn(para, question + answer) + inp = model_preprocessing_interface.boundary_token_fn( + para, s2=question + answer, trunc_strategy="trunc_s1", trunc_side="right" + ) d["psg_qst_ans"] = sentence_to_text_field(inp, indexers) else: d["psg"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn(passage), indexers + model_preprocessing_interface.boundary_token_fn( + passage, trunc_strategy="trunc_s1", trunc_side="right" + ), + indexers, ) d["qst"] = sentence_to_text_field( model_preprocessing_interface.boundary_token_fn(question), indexers @@ -308,11 +313,16 @@ def _make_instance(psg, qst, ans_str, label, psg_idx, qst_idx, ans_idx): d["ans_idx"] = MetadataField(ans_idx) d["idx"] = MetadataField(ans_idx) # required by evaluate() if model_preprocessing_interface.model_flags["uses_pair_embedding"]: - inp = model_preprocessing_interface.boundary_token_fn(psg, qst) + inp = model_preprocessing_interface.boundary_token_fn( + psg, s2=qst, trunc_strategy="trunc_s1", trunc_side="right" + ) d["psg_qst_ans"] = sentence_to_text_field(inp, indexers) else: d["psg"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn(psg), indexers + model_preprocessing_interface.boundary_token_fn( + psg, trunc_strategy="trunc_s1", trunc_side="right" + ), + indexers, ) d["qst"] = sentence_to_text_field( model_preprocessing_interface.boundary_token_fn(qst), indexers @@ -458,12 +468,19 @@ def _make_instance(example): if 
model_preprocessing_interface.model_flags["uses_pair_embedding"]: inp, start_offset, _ = model_preprocessing_interface.boundary_token_fn( - example["passage"], example["question"], get_offset=True + example["passage"], + s2=example["question"], + trunc_strategy="trunc_s1", + trunc_side="right", + get_offset=True, ) d["inputs"] = sentence_to_text_field(inp, indexers) else: d["passage"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn(example["passage"]), indexers + model_preprocessing_interface.boundary_token_fn( + example["passage"], trunc_strategy="trunc_s1", trunc_side="right" + ), + indexers, ) d["question"] = sentence_to_text_field( model_preprocessing_interface.boundary_token_fn(example["question"]), indexers @@ -614,12 +631,19 @@ def _make_instance(example): if model_preprocessing_interface.model_flags["uses_pair_embedding"]: inp, start_offset, _ = model_preprocessing_interface.boundary_token_fn( - example["passage"], example["question"], get_offset=True + example["passage"], + s2=example["question"], + trunc_strategy="trunc_s1", + trunc_side="right", + get_offset=True, ) d["inputs"] = sentence_to_text_field(inp, indexers) else: d["passage"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn(example["passage"]), indexers + model_preprocessing_interface.boundary_token_fn( + example["passage"], trunc_strategy="trunc_s1", trunc_side="right" + ), + indexers, ) d["question"] = sentence_to_text_field( model_preprocessing_interface.boundary_token_fn(example["question"]), indexers @@ -885,11 +909,16 @@ def _make_instance(question, choices, label, id_str): d["question_str"] = MetadataField(" ".join(question)) if not model_preprocessing_interface.model_flags["uses_pair_embedding"]: d["question"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn(question), indexers + model_preprocessing_interface.boundary_token_fn( + question, trunc_strategy="trunc_s1", trunc_side="right" + ), + indexers, ) 
for choice_idx, choice in enumerate(choices): inp = ( - model_preprocessing_interface.boundary_token_fn(question, choice) + model_preprocessing_interface.boundary_token_fn( + question, s2=choice, trunc_strategy="trunc_s1", trunc_side="right" + ) if model_preprocessing_interface.model_flags["uses_pair_embedding"] else model_preprocessing_interface.boundary_token_fn(choice) ) @@ -983,11 +1012,16 @@ def _make_instance(context, choices, label, id_str): d["context_str"] = MetadataField(" ".join(context)) if not model_preprocessing_interface.model_flags["uses_pair_embedding"]: d["context"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn(context), indexers + model_preprocessing_interface.boundary_token_fn( + context, trunc_strategy="trunc_s1", trunc_side="right" + ), + indexers, ) for choice_idx, choice in enumerate(choices): inp = ( - model_preprocessing_interface.boundary_token_fn(context, choice) + model_preprocessing_interface.boundary_token_fn( + context, s2=choice, trunc_strategy="trunc_s1", trunc_side="right" + ) if model_preprocessing_interface.model_flags["uses_pair_embedding"] else model_preprocessing_interface.boundary_token_fn(choice) ) diff --git a/jiant/tasks/tasks.py b/jiant/tasks/tasks.py index a21a6cb34..512a0b37d 100644 --- a/jiant/tasks/tasks.py +++ b/jiant/tasks/tasks.py @@ -108,18 +108,25 @@ def _make_instance(input1, input2, labels, idx): d = {} d["sent1_str"] = MetadataField(" ".join(input1)) if model_preprocessing_interface.model_flags["uses_pair_embedding"] and is_pair: - inp = model_preprocessing_interface.boundary_token_fn(input1, input2) + inp = model_preprocessing_interface.boundary_token_fn( + input1, s2=input2, trunc_strategy="trunc_s1", trunc_side="right" + ) d["inputs"] = sentence_to_text_field(inp, indexers) d["sent2_str"] = MetadataField(" ".join(input2)) if ( model_preprocessing_interface.model_flags["uses_mirrored_pair"] and is_symmetrical_pair ): - inp_m = 
model_preprocessing_interface.boundary_token_fn(input1, input2) + inp_m = model_preprocessing_interface.boundary_token_fn( + input1, s2=input2, trunc_strategy="trunc_s1", trunc_side="right" + ) d["inputs_m"] = sentence_to_text_field(inp_m, indexers) else: d["input1"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn(input1), indexers + model_preprocessing_interface.boundary_token_fn( + input1, trunc_strategy="trunc_s1", trunc_side="right" + ), + indexers, ) if input2: d["input2"] = sentence_to_text_field( @@ -788,7 +795,10 @@ def _make_instance(input1, labels, tagids): """ from multiple types in one column create multiple fields """ d = {} d["input1"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn(input1), indexers + model_preprocessing_interface.boundary_token_fn( + input1, trunc_strategy="trunc_s1", trunc_side="right" + ), + indexers, ) d["sent1_str"] = MetadataField(" ".join(input1)) d["labels"] = LabelField(labels, label_namespace="labels", skip_indexing=True) @@ -1621,11 +1631,16 @@ def _make_instance(input1, input2, label, idx, lex_sem, pr_ar_str, logic, knowle """ from multiple types in one column create multiple fields """ d = {} if model_preprocessing_interface.model_flags["uses_pair_embedding"]: - inp = model_preprocessing_interface.boundary_token_fn(input1, input2) + inp = model_preprocessing_interface.boundary_token_fn( + input1, s2=input2, trunc_strategy="trunc_s1", trunc_side="right" + ) d["inputs"] = sentence_to_text_field(inp, indexers) else: d["input1"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn(input1), indexers + model_preprocessing_interface.boundary_token_fn( + input1, trunc_strategy="trunc_s1", trunc_side="right" + ), + indexers, ) d["input2"] = sentence_to_text_field( model_preprocessing_interface.boundary_token_fn(input2), indexers @@ -1829,12 +1844,17 @@ def _make_instance(input1, input2, labels, idx, pair_id): d = {} d["sent1_str"] = MetadataField(" 
".join(input1)) if model_preprocessing_interface.model_flags["uses_pair_embedding"]: - inp = model_preprocessing_interface.boundary_token_fn(input1, input2) + inp = model_preprocessing_interface.boundary_token_fn( + input1, s2=input2, trunc_strategy="trunc_s1", trunc_side="right" + ) d["inputs"] = sentence_to_text_field(inp, indexers) d["sent2_str"] = MetadataField(" ".join(input2)) else: d["input1"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn(input1), indexers + model_preprocessing_interface.boundary_token_fn( + input1, trunc_strategy="trunc_s1", trunc_side="right" + ), + indexers, ) if input2: d["input2"] = sentence_to_text_field( @@ -2191,7 +2211,10 @@ def process_split( def _make_instance(input1, input2, labels): d = {} d["input1"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn(input1), indexers + model_preprocessing_interface.boundary_token_fn( + input1, trunc_strategy="trunc_s1", trunc_side="right" + ), + indexers, ) d["input2"] = sentence_to_text_field( model_preprocessing_interface.boundary_token_fn(input2), indexers @@ -2291,12 +2314,15 @@ def _make_instance(input1, input2, labels): d = {} if model_preprocessing_interface.model_flags["uses_pair_embedding"]: inp = model_preprocessing_interface.boundary_token_fn( - input1, input2 + input1, s2=input2, trunc_strategy="trunc_s1", trunc_side="right" ) # drop leading [CLS] token d["inputs"] = sentence_to_text_field(inp, indexers) else: d["input1"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn(input1), indexers + model_preprocessing_interface.boundary_token_fn( + input1, trunc_strategy="trunc_s1", trunc_side="right" + ), + indexers, ) d["input2"] = sentence_to_text_field( model_preprocessing_interface.boundary_token_fn(input2), indexers @@ -2435,7 +2461,10 @@ def process_split( def _make_instance(input1, input2, target, mask): d = {} d["inputs"] = sentence_to_text_field( - 
model_preprocessing_interface.boundary_token_fn(input1), indexers + model_preprocessing_interface.boundary_token_fn( + input1, trunc_strategy="trunc_s1", trunc_side="right" + ), + indexers, ) d["sent1_str"] = MetadataField(" ".join(input1)) d["targs"] = sentence_to_text_field(target, self.target_indexer) @@ -2617,7 +2646,9 @@ def _make_span_field(self, s, text_field, offset=1): def make_instance(self, record, idx, indexers, model_preprocessing_interface) -> Type[Instance]: """Convert a single record to an AllenNLP Instance.""" tokens = record["text"].split() - tokens, offset = model_preprocessing_interface.boundary_token_fn(tokens, get_offset=True) + tokens, offset = model_preprocessing_interface.boundary_token_fn( + tokens, trunc_strategy="trunc_s1", trunc_side="right", get_offset=True + ) text_field = sentence_to_text_field(tokens, indexers) example = {} @@ -2835,12 +2866,16 @@ def _make_instance(input1, input2, idxs1, idxs2, labels, idx): d["sent2_str"] = MetadataField(" ".join(input2)) if model_preprocessing_interface.model_flags["uses_pair_embedding"]: inp, offset1, offset2 = model_preprocessing_interface.boundary_token_fn( - input1, input2, get_offset=True + input1, + s2=input2, + trunc_strategy="trunc_s1", + trunc_side="right", + get_offset=True, ) d["inputs"] = sentence_to_text_field(inp[: self.max_seq_len], indexers) else: inp1, offset1 = model_preprocessing_interface.boundary_token_fn( - input1, get_offset=True + input1, trunc_strategy="trunc_s1", trunc_side="right", get_offset=True ) inp2, offset2 = model_preprocessing_interface.boundary_token_fn( input2, get_offset=True @@ -2968,11 +3003,16 @@ def _make_instance(context, choices, question, label, idx): d["question_str"] = MetadataField(" ".join(context)) if not model_preprocessing_interface.model_flags["uses_pair_embedding"]: d["question"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn(context), indexers + model_preprocessing_interface.boundary_token_fn( + context, 
trunc_strategy="trunc_s1", trunc_side="right" + ), + indexers, ) for choice_idx, choice in enumerate(choices): inp = ( - model_preprocessing_interface.boundary_token_fn(context, question + choice) + model_preprocessing_interface.boundary_token_fn( + context, s2=question + choice, trunc_strategy="trunc_s1", trunc_side="right" + ) if model_preprocessing_interface.model_flags["uses_pair_embedding"] else model_preprocessing_interface.boundary_token_fn(choice) ) @@ -3131,11 +3171,16 @@ def _make_instance(context, choices, question, label, idx): d["question_str"] = MetadataField(" ".join(context)) if not model_preprocessing_interface.model_flags["uses_pair_embedding"]: d["question"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn(context), indexers + model_preprocessing_interface.boundary_token_fn( + context, trunc_strategy="trunc_s1", trunc_side="right" + ), + indexers, ) for choice_idx, choice in enumerate(choices): inp = ( - model_preprocessing_interface.boundary_token_fn(context, question + choice) + model_preprocessing_interface.boundary_token_fn( + context, s2=question + choice, trunc_strategy="trunc_s1", trunc_side="right" + ) if model_preprocessing_interface.model_flags["uses_pair_embedding"] else model_preprocessing_interface.boundary_token_fn(choice) ) @@ -3217,11 +3262,16 @@ def _make_instance(question, choices, label, idx): d["question_str"] = MetadataField(" ".join(question)) if not model_preprocessing_interface.model_flags["uses_pair_embedding"]: d["question"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn(question), indexers + model_preprocessing_interface.boundary_token_fn( + question, trunc_strategy="trunc_s1", trunc_side="right" + ), + indexers, ) for choice_idx, choice in enumerate(choices): inp = ( - model_preprocessing_interface.boundary_token_fn(question, choice) + model_preprocessing_interface.boundary_token_fn( + question, s2=choice, trunc_strategy="trunc_s1", trunc_side="right" + ) if 
model_preprocessing_interface.model_flags["uses_pair_embedding"] else model_preprocessing_interface.boundary_token_fn(choice) ) @@ -3308,11 +3358,16 @@ def _make_instance(question, choices, label, idx): d["question_str"] = MetadataField(" ".join(question)) if not model_preprocessing_interface.model_flags["uses_pair_embedding"]: d["question"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn(question), indexers + model_preprocessing_interface.boundary_token_fn( + question, trunc_strategy="trunc_s1", trunc_side="right" + ), + indexers, ) for choice_idx, choice in enumerate(choices): inp = ( - model_preprocessing_interface.boundary_token_fn(question, choice) + model_preprocessing_interface.boundary_token_fn( + question, s2=choice, trunc_strategy="trunc_s1", trunc_side="right" + ) if model_preprocessing_interface.model_flags["uses_pair_embedding"] else model_preprocessing_interface.boundary_token_fn(choice) ) @@ -3431,14 +3486,17 @@ def _make_instance(d, idx): new_d["passage_str"] = MetadataField(" ".join(d["passage"])) if not model_preprocessing_interface.model_flags["uses_pair_embedding"]: new_d["input1"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn(d["passage"]), indexers + model_preprocessing_interface.boundary_token_fn( + d["passage"], trunc_strategy="trunc_s1", trunc_side="right" + ), + indexers, ) new_d["input2"] = sentence_to_text_field( model_preprocessing_interface.boundary_token_fn(d["question"]), indexers ) else: # BERT/XLNet psg_qst = model_preprocessing_interface.boundary_token_fn( - d["passage"], d["question"] + d["passage"], s2=d["question"], trunc_strategy="trunc_s1", trunc_side="right" ) new_d["inputs"] = sentence_to_text_field(psg_qst, indexers) new_d["labels"] = LabelField(d["label"], label_namespace="labels", skip_indexing=True) @@ -3561,7 +3619,7 @@ def _make_instance(obs1, hyp1, hyp2, obs2, label, idx): else: for hyp_idx, hyp in enumerate([hyp1, hyp2]): inp = ( - 
model_preprocessing_interface.boundary_token_fn(obs1 + hyp, obs2) + model_preprocessing_interface.boundary_token_fn(obs1 + hyp, s2=obs2) if model_preprocessing_interface.model_flags["uses_pair_embedding"] else model_preprocessing_interface.boundary_token_fn(hyp) ) @@ -3709,11 +3767,16 @@ def _make_instance(context, choices, label, idx): d["question_str"] = MetadataField(" ".join(context)) if not model_preprocessing_interface.model_flags["uses_pair_embedding"]: d["question"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn(context), indexers + model_preprocessing_interface.boundary_token_fn( + context, trunc_strategy="trunc_s1", trunc_side="right" + ), + indexers, ) for choice_idx, choice in enumerate(choices): inp = ( - model_preprocessing_interface.boundary_token_fn(context, choice) + model_preprocessing_interface.boundary_token_fn( + context, s2=choice, trunc_strategy="trunc_s1", trunc_side="right" + ) if model_preprocessing_interface.model_flags["uses_pair_embedding"] else model_preprocessing_interface.boundary_token_fn(choice) ) diff --git a/jiant/utils/utils.py b/jiant/utils/utils.py index 3dca1174c..34c87a428 100644 --- a/jiant/utils/utils.py +++ b/jiant/utils/utils.py @@ -98,7 +98,7 @@ def select_pool_type(args): return pool_type -def apply_standard_boundary_tokens(s1, s2=None): +def apply_standard_boundary_tokens(s1, *, s2=None): """Apply and to sequences of string-valued tokens. Corresponds to more complex functions used with models like XLNet and BERT. 
""" diff --git a/tests/test_huggingface_transformers_interface.py b/tests/test_huggingface_transformers_interface.py index de5e0dc7e..bb6eabdde 100644 --- a/tests/test_huggingface_transformers_interface.py +++ b/tests/test_huggingface_transformers_interface.py @@ -23,7 +23,7 @@ def test_bert_apply_boundary_tokens(self): BertEmbedderModule.apply_boundary_tokens(s1), ["[CLS]", "A", "B", "C", "[SEP]"] ) self.assertListEqual( - BertEmbedderModule.apply_boundary_tokens(s1, s2), + BertEmbedderModule.apply_boundary_tokens(s1, s2=s2), ["[CLS]", "A", "B", "C", "[SEP]", "D", "E", "[SEP]"], ) @@ -34,7 +34,7 @@ def test_roberta_apply_boundary_tokens(self): RobertaEmbedderModule.apply_boundary_tokens(s1), ["", "A", "B", "C", ""] ) self.assertListEqual( - RobertaEmbedderModule.apply_boundary_tokens(s1, s2), + RobertaEmbedderModule.apply_boundary_tokens(s1, s2=s2), ["", "A", "B", "C", "", "", "D", "E", ""], ) @@ -45,7 +45,7 @@ def test_albert_apply_boundary_tokens(self): AlbertEmbedderModule.apply_boundary_tokens(s1), ["[CLS]", "A", "B", "C", "[SEP]"] ) self.assertListEqual( - AlbertEmbedderModule.apply_boundary_tokens(s1, s2), + AlbertEmbedderModule.apply_boundary_tokens(s1, s2=s2), ["[CLS]", "A", "B", "C", "[SEP]", "D", "E", "[SEP]"], ) @@ -56,7 +56,7 @@ def test_xlnet_apply_boundary_tokens(self): XLNetEmbedderModule.apply_boundary_tokens(s1), ["A", "B", "C", "", ""] ) self.assertListEqual( - XLNetEmbedderModule.apply_boundary_tokens(s1, s2), + XLNetEmbedderModule.apply_boundary_tokens(s1, s2=s2), ["A", "B", "C", "", "D", "E", "", ""], ) @@ -68,7 +68,7 @@ def test_gpt_apply_boundary_tokens(self): ["", "A", "B", "C", ""], ) self.assertListEqual( - OpenAIGPTEmbedderModule.apply_boundary_tokens(s1, s2), + OpenAIGPTEmbedderModule.apply_boundary_tokens(s1, s2=s2), ["", "A", "B", "C", "", "D", "E", ""], ) @@ -79,7 +79,7 @@ def test_xlm_apply_boundary_tokens(self): XLMEmbedderModule.apply_boundary_tokens(s1), ["", "A", "B", "C", ""] ) self.assertListEqual( - 
XLMEmbedderModule.apply_boundary_tokens(s1, s2), + XLMEmbedderModule.apply_boundary_tokens(s1, s2=s2), ["</s>", "A", "B", "C", "</s>", "D", "E", "</s>"], ) diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index ef95a8031..2e2c6bda1 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -8,8 +8,14 @@ from unittest import mock import jiant.tasks.tasks as tasks +from jiant.utils import config from jiant.utils.config import params_from_file -from jiant.preprocess import get_task_without_loading_data, build_indexers, get_vocab +from jiant.preprocess import ( + get_task_without_loading_data, + build_indexers, + get_vocab, + ModelPreprocessingInterface, +) class TestProprocess(unittest.TestCase): @@ -75,3 +81,118 @@ def test_build_vocab(self): assert set(vocab.get_index_to_token_vocabulary("chars").values()) == set( ["@@PADDING@@", "@@UNKNOWN@@", "a", "b", "c"] ) + + +class TestModelPreprocessingInterface(unittest.TestCase): + def test_max_tokens_limited_by_small_max_seq_len(self): + args = config.Params(max_seq_len=7, input_module="bert-base-uncased") + mpi = ModelPreprocessingInterface(args) + self.assertEqual(mpi.max_tokens, 7) + + def test_max_tokens_limited_by_bert_model_max(self): + args = config.Params(max_seq_len=None, input_module="bert-base-uncased") + mpi = ModelPreprocessingInterface(args) + self.assertEqual(mpi.max_tokens, 512) + + def test_boundary_token_fn_trunc_w_default_strategies(self): + MAX_SEQ_LEN = 7 + args = config.Params(max_seq_len=MAX_SEQ_LEN, input_module="bert-base-uncased") + mpi = ModelPreprocessingInterface(args) + seq = mpi.boundary_token_fn( + ["Apple", "buy", "call"], + s2=["Xray", "you", "zoo"], + trunc_strategy="trunc_s2", + trunc_side="right", + ) + self.assertEqual(len(seq), MAX_SEQ_LEN) + self.assertEqual(seq, ["[CLS]", "Apple", "buy", "call", "[SEP]", "Xray", "[SEP]"]) + + def test_boundary_token_fn_trunc_s2_left_side(self): + MAX_SEQ_LEN = 7 + args = config.Params(max_seq_len=MAX_SEQ_LEN, 
input_module="bert-base-uncased") + mpi = ModelPreprocessingInterface(args) + seq = mpi.boundary_token_fn( + ["Apple", "buy", "call"], + s2=["Xray", "you", "zoo"], + trunc_strategy="trunc_s2", + trunc_side="left", + ) + self.assertEqual(len(seq), MAX_SEQ_LEN) + self.assertEqual(seq, ["[CLS]", "Apple", "buy", "call", "[SEP]", "zoo", "[SEP]"]) + + def test_boundary_token_fn_trunc_s2_right_side(self): + MAX_SEQ_LEN = 7 + args = config.Params(max_seq_len=MAX_SEQ_LEN, input_module="bert-base-uncased") + mpi = ModelPreprocessingInterface(args) + seq = mpi.boundary_token_fn( + ["Apple", "buy", "call"], + s2=["Xray", "you", "zoo"], + trunc_strategy="trunc_s2", + trunc_side="right", + ) + self.assertEqual(len(seq), MAX_SEQ_LEN) + self.assertEqual(seq, ["[CLS]", "Apple", "buy", "call", "[SEP]", "Xray", "[SEP]"]) + + def test_boundary_token_fn_trunc_s1_left_side(self): + MAX_SEQ_LEN = 7 + args = config.Params(max_seq_len=MAX_SEQ_LEN, input_module="bert-base-uncased") + mpi = ModelPreprocessingInterface(args) + seq = mpi.boundary_token_fn( + ["Apple", "buy", "call"], + s2=["Xray", "you", "zoo"], + trunc_strategy="trunc_s1", + trunc_side="left", + ) + self.assertEqual(len(seq), MAX_SEQ_LEN) + self.assertEqual(seq, ["[CLS]", "call", "[SEP]", "Xray", "you", "zoo", "[SEP]"]) + + def test_boundary_token_fn_trunc_s1_right_side(self): + MAX_SEQ_LEN = 7 + args = config.Params(max_seq_len=MAX_SEQ_LEN, input_module="bert-base-uncased") + mpi = ModelPreprocessingInterface(args) + seq = mpi.boundary_token_fn( + ["Apple", "buy", "call"], + s2=["Xray", "you", "zoo"], + trunc_strategy="trunc_s1", + trunc_side="right", + ) + self.assertEqual(len(seq), MAX_SEQ_LEN) + self.assertEqual(seq, ["[CLS]", "Apple", "[SEP]", "Xray", "you", "zoo", "[SEP]"]) + + def test_boundary_token_fn_trunc_both(self): + MAX_SEQ_LEN = 7 + args = config.Params(max_seq_len=MAX_SEQ_LEN, input_module="bert-base-uncased") + mpi = ModelPreprocessingInterface(args) + seq = mpi.boundary_token_fn( + ["Apple", "buy", "call"], 
+ s2=["Xray", "you", "zoo"], + trunc_strategy="trunc_both", + trunc_side="left", + ) + self.assertEqual(len(seq), MAX_SEQ_LEN) + self.assertEqual(seq, ["[CLS]", "buy", "call", "[SEP]", "you", "zoo", "[SEP]"]) + + def test_boundary_token_fn_trunc_with_short_sequence_and_no_max_seq_len(self): + args = config.Params(max_seq_len=None, input_module="bert-base-uncased") + mpi = ModelPreprocessingInterface(args) + seq = mpi.boundary_token_fn(["Apple", "buy", "call"], s2=["Xray", "you", "zoo"]) + self.assertEqual( + seq, ["[CLS]", "Apple", "buy", "call", "[SEP]", "Xray", "you", "zoo", "[SEP]"] + ) + + def test_boundary_token_fn_throws_exception_when_trunc_is_needed_but_strat_unspecified(self): + args = config.Params(max_seq_len=1, input_module="bert-base-uncased") + mpi = ModelPreprocessingInterface(args) + self.assertRaises(ValueError, mpi.boundary_token_fn, ["Apple", "buy", "call"], s2=["Xray"]) + + def test_boundary_token_fn_throws_exception_when_trunc_beyond_input_length(self): + args = config.Params(max_seq_len=1, input_module="bert-base-uncased") + mpi = ModelPreprocessingInterface(args) + self.assertRaises( + AssertionError, + mpi.boundary_token_fn, + ["Apple", "buy", "call"], + s2=["Xray"], + trunc_strategy="trunc_s2", + trunc_side="right", + )