From 5d7a1c853e32a9cdcc51609a64ba557db1cae6c7 Mon Sep 17 00:00:00 2001 From: Phil Yeres Date: Tue, 25 Feb 2020 14:31:40 -0500 Subject: [PATCH] Revert "Address max sequence length issue (#1002) (#1006)" (#1013) --- .../modules.py | 18 +- jiant/preprocess.py | 155 +----------------- jiant/tasks/edge_probing.py | 2 +- jiant/tasks/qa.py | 58 ++----- jiant/tasks/tasks.py | 119 ++++---------- jiant/utils/utils.py | 2 +- ...test_huggingface_transformers_interface.py | 12 +- tests/test_preprocess.py | 123 +------------- 8 files changed, 61 insertions(+), 428 deletions(-) diff --git a/jiant/huggingface_transformers_interface/modules.py b/jiant/huggingface_transformers_interface/modules.py index 8b10998e7..d8957e0cd 100644 --- a/jiant/huggingface_transformers_interface/modules.py +++ b/jiant/huggingface_transformers_interface/modules.py @@ -189,7 +189,7 @@ def get_seg_ids(self, token_ids, input_mask): return seg_ids @staticmethod - def apply_boundary_tokens(s1, *, s2=None, get_offset=False): + def apply_boundary_tokens(s1, s2=None, get_offset=False): """ A function that appliese the appropriate EOS/SOS/SEP/CLS tokens to token sequence or token sequence pair for most tasks. @@ -274,7 +274,7 @@ def __init__(self, args): self.parameter_setup(args) @staticmethod - def apply_boundary_tokens(s1, *, s2=None, get_offset=False): + def apply_boundary_tokens(s1, s2=None, get_offset=False): # BERT-style boundary token padding on string token sequences if s2: s = ["[CLS]"] + s1 + ["[SEP]"] + s2 + ["[SEP]"] @@ -331,7 +331,7 @@ def __init__(self, args): self.parameter_setup(args) @staticmethod - def apply_boundary_tokens(s1, *, s2=None, get_offset=False): + def apply_boundary_tokens(s1, s2=None, get_offset=False): # RoBERTa-style boundary token padding on string token sequences if s2: s = [""] + s1 + ["", ""] + s2 + [""] @@ -385,7 +385,7 @@ def __init__(self, args): self.parameter_setup(args) @staticmethod - def apply_boundary_tokens(s1, *, s2=None, get_offset=False): + def apply_boundary_tokens(s1, s2=None, get_offset=False): # ALBERT-style boundary token padding on string token sequences if s2: s = ["[CLS]"] + s1 + ["[SEP]"] + s2 + ["[SEP]"] @@ -448,7 +448,7 @@ def __init__(self, args): self._SEG_ID_SEP = 3 @staticmethod - def apply_boundary_tokens(s1, *, s2=None, get_offset=False): + def apply_boundary_tokens(s1, s2=None, get_offset=False): # XLNet-style boundary token marking on string token sequences if s2: s = s1 + [""] + s2 + ["", ""] @@ -506,7 +506,7 @@ def __init__(self, args): self.parameter_setup(args) @staticmethod - def apply_boundary_tokens(s1, *, s2=None, get_offset=False): + def apply_boundary_tokens(s1, s2=None, get_offset=False): # OpenAI-GPT-style boundary token marking on string token sequences if s2: s = [""] + s1 + [""] + s2 + [""] @@ -568,7 +568,7 @@ def __init__(self, args): self.parameter_setup(args) @staticmethod - def apply_boundary_tokens(s1, *, s2=None, get_offset=False): + def apply_boundary_tokens(s1, s2=None, get_offset=False): # GPT-2-style boundary token marking on string token sequences if s2: s = [""] + s1 + [""] + s2 + [""] @@ -630,7 +630,7 @@ def __init__(self, args): self.parameter_setup(args) @staticmethod - def apply_boundary_tokens(s1, *, s2=None, get_offset=False): + def apply_boundary_tokens(s1, s2=None, get_offset=False): # TransformerXL-style boundary token marking on string token sequences if s2: s = [""] + s1 + [""] + s2 + [""] @@ -697,7 +697,7 @@ def __init__(self, args): self.parameter_setup(args) @staticmethod - def apply_boundary_tokens(s1, *, s2=None, get_offset=False): + def apply_boundary_tokens(s1, s2=None, get_offset=False): # XLM-style boundary token marking on string token sequences if s2: s = [""] + s1 + [""] + s2 + [""] diff --git a/jiant/preprocess.py b/jiant/preprocess.py index 815c383d6..27e3b6e18 100644 --- a/jiant/preprocess.py +++ b/jiant/preprocess.py @@ -740,202 +740,53 @@ class ModelPreprocessingInterface(object): """ - @staticmethod - def _apply_boundary_tokens(apply_boundary_tokens, max_tokens): - """ - Takes a function that applies boundary tokens and modifies it to respect a max_seq_len arg. - - Parameters - ---------- - apply_boundary_tokens : function - Takes a function that adds boundary tokens and implements the common boundary token - adding function interface demonstated in HuggingfaceTransformersEmbedderModule. - max_tokens : int - The maximum number of tokens to allow in the output sequence. - - Returns - ------- - apply_boundary_tokens_with_trunc_strategy : function - Function w/ the common boundary token adding interface demonstrated in - HuggingfaceTransformersEmbedderModule, but also implementing a truncation strategy. - """ - - def _apply_boundary_tokens_with_trunc_strategy( - *args, trunc_strategy=None, trunc_side=None, **kwargs - ): - """ - Calls the apply_boundary_tokens function provided in the parent function and, if the - output exceeds the max number of tokens, applies a truncation strategy to the inputs, - then re-applies the apply_boundary_tokens function on the truncated inputs and returns - the result. - - Parameters - ---------- - args : tuple - see args docs for apply_boundary_tokens and apply_lm_boundary_tokens in - huggingface_transformers_interface.modules - trunc_strategy : str - Which strings to truncate. Options: - `trunc_s1`: select s1 for truncation. - `trunc_s2`: select s2 for truncation. - `trunc_both`: truncate s1 and s2 equally. - trunc_side : str - Options are `right` or `left`. Indicates which side of the sequence to remove from. - e.g., if `left`, then the first element to be removed from ["A", "B", "C"] is "A". - kwargs : dict - see args docs for apply_boundary_tokens and apply_lm_boundary_tokens in - huggingface_transformers_interface.modules - - Returns - ------- - seq_w_boundry_tokens : List[str] - List of tokens returned by applying the specified apply_boundary_tokens function - to the inputs and using the specified truncation strategy. - - Raises - ------ - ValueError - If sequence requires truncation but trunc_strategy or trunc_side are unspecified - or invalid. - - """ - seq_w_boundry_tokens = apply_boundary_tokens(*args, **kwargs) - # if after calling the apply_boundary_tokens function the number of tokens is greater - # than the model's max, a truncation strategy can reduce the number of tokens: - num_excess_tokens = len(seq_w_boundry_tokens) - max_tokens - if num_excess_tokens > 0: - if not (trunc_strategy and trunc_side): - raise ValueError( - "Input(s) length will exceed model capacity or max_seq_len. " - + "Adjust settings as necessary, or, to automatically truncate " - + "inputs in `apply_boundary_tokens`, call `apply_boundary_tokens` " - + "with `trunc_strategy` and `trunc_side` keyword args." - ) - if trunc_strategy == "trunc_s2": - s2 = kwargs["s2"] - if trunc_side == "right": - s2_truncated = s2[:-num_excess_tokens] - elif trunc_side == "left": - s2_truncated = s2[num_excess_tokens:] - log.info( - "Before truncation s2 length = " - + str(len(s2)) - + ", after truncation s2 length = " - + str(len(s2_truncated)) - ) - assert len(s2_truncated) > 0, "After truncation, s2 length would be 0." - args = list(args) - kwargs["s2"] = s2_truncated - return apply_boundary_tokens(*args, **kwargs) - elif trunc_strategy == "trunc_both": - s1 = args[0] - s2 = kwargs["s2"] - if trunc_side == "right": - s1_truncated = s1[: -num_excess_tokens // 2] - s2_truncated = s2[: -num_excess_tokens // 2] - elif trunc_side == "left": - s1_truncated = s1[num_excess_tokens // 2 :] - s2_truncated = s2[num_excess_tokens // 2 :] - log.info( - "Before truncation s1 length = " - + str(len(s1)) - + " and s2 length = " - + str(len(s2)) - + ". After truncation s1 length = " - + str(len(s1_truncated)) - + " and s2 length = " - + str(len(s2_truncated)) - ) - assert len(s1_truncated) > 0, "After truncation, s1 length would be 0." - assert len(s2_truncated) > 0, "After truncation, s2 length would be 0." - args = list(args) - args[0] = s1_truncated - kwargs["s2"] = s2_truncated - return apply_boundary_tokens(*args, **kwargs) - elif trunc_strategy == "trunc_s1": - s1 = args[0] - if trunc_side == "right": - s1_truncated = s1[:-num_excess_tokens] - elif trunc_side == "left": - s1_truncated = s1[num_excess_tokens:] - log.info( - "Before truncation s1 length = " - + str(len(s1)) - + ", after truncation s1 length = " - + str(len(s1_truncated)) - ) - assert len(s1_truncated) > 0, "After truncation, s1 length would be 0." - args = list(args) - args[0] = s1_truncated - return apply_boundary_tokens(*args, **kwargs) - else: - raise ValueError(trunc_strategy + " is not a valid truncation strategy.") - else: - return seq_w_boundry_tokens - - return _apply_boundary_tokens_with_trunc_strategy - def __init__(self, args): boundary_token_fn = None lm_boundary_token_fn = None - max_pos: int = None if args.input_module.startswith("bert-"): from jiant.huggingface_transformers_interface.modules import BertEmbedderModule boundary_token_fn = BertEmbedderModule.apply_boundary_tokens - max_pos = BertTokenizer.max_model_input_sizes.get(args.input_module, None) elif args.input_module.startswith("roberta-"): from jiant.huggingface_transformers_interface.modules import RobertaEmbedderModule boundary_token_fn = RobertaEmbedderModule.apply_boundary_tokens - max_pos = RobertaTokenizer.max_model_input_sizes.get(args.input_module, None) elif args.input_module.startswith("albert-"): from jiant.huggingface_transformers_interface.modules import AlbertEmbedderModule boundary_token_fn = AlbertEmbedderModule.apply_boundary_tokens - max_pos = AlbertTokenizer.max_model_input_sizes.get(args.input_module, None) elif args.input_module.startswith("xlnet-"): from jiant.huggingface_transformers_interface.modules import XLNetEmbedderModule boundary_token_fn = XLNetEmbedderModule.apply_boundary_tokens - max_pos = XLNetTokenizer.max_model_input_sizes.get(args.input_module, None) elif args.input_module.startswith("openai-gpt"): from jiant.huggingface_transformers_interface.modules import OpenAIGPTEmbedderModule boundary_token_fn = OpenAIGPTEmbedderModule.apply_boundary_tokens lm_boundary_token_fn = OpenAIGPTEmbedderModule.apply_lm_boundary_tokens - max_pos = OpenAIGPTTokenizer.max_model_input_sizes.get(args.input_module, None) elif args.input_module.startswith("gpt2"): from jiant.huggingface_transformers_interface.modules import GPT2EmbedderModule boundary_token_fn = GPT2EmbedderModule.apply_boundary_tokens lm_boundary_token_fn = GPT2EmbedderModule.apply_lm_boundary_tokens - max_pos = GPT2Tokenizer.max_model_input_sizes.get(args.input_module, None) elif args.input_module.startswith("transfo-xl-"): from jiant.huggingface_transformers_interface.modules import TransfoXLEmbedderModule boundary_token_fn = TransfoXLEmbedderModule.apply_boundary_tokens lm_boundary_token_fn = TransfoXLEmbedderModule.apply_lm_boundary_tokens - max_pos = TransfoXLTokenizer.max_model_input_sizes.get(args.input_module, None) elif args.input_module.startswith("xlm-"): from jiant.huggingface_transformers_interface.modules import XLMEmbedderModule boundary_token_fn = XLMEmbedderModule.apply_boundary_tokens - max_pos = XLMTokenizer.max_model_input_sizes.get(args.input_module, None) else: boundary_token_fn = utils.apply_standard_boundary_tokens - self.max_tokens = min(x for x in [max_pos, args.max_seq_len] if x is not None) - - self.boundary_token_fn = self._apply_boundary_tokens(boundary_token_fn, self.max_tokens) - + self.boundary_token_fn = boundary_token_fn if lm_boundary_token_fn is not None: - self.lm_boundary_token_fn = self._apply_boundary_tokens( - lm_boundary_token_fn, args.max_seq_len - ) + self.lm_boundary_token_fn = lm_boundary_token_fn else: - self.lm_boundary_token_fn = self.boundary_token_fn + self.lm_boundary_token_fn = boundary_token_fn from jiant.models import input_module_uses_pair_embedding, input_module_uses_mirrored_pair diff --git a/jiant/tasks/edge_probing.py b/jiant/tasks/edge_probing.py index b6607382b..6821441ef 100644 --- a/jiant/tasks/edge_probing.py +++ b/jiant/tasks/edge_probing.py @@ -182,7 +182,7 @@ def make_instance(self, record, idx, indexers, model_preprocessing_interface) -> """Convert a single record to an AllenNLP Instance.""" tokens = record["text"].split() # already space-tokenized by Moses tokens = model_preprocessing_interface.boundary_token_fn( - tokens, trunc_strategy="trunc_s1", trunc_side="right" + tokens ) # apply model-appropriate variants of [cls] and [sep]. text_field = sentence_to_text_field(tokens, indexers) diff --git a/jiant/tasks/qa.py b/jiant/tasks/qa.py index 7b2492908..24da6d68f 100644 --- a/jiant/tasks/qa.py +++ b/jiant/tasks/qa.py @@ -120,16 +120,11 @@ def _make_instance(passage, question, answer, label, par_idx, qst_idx, ans_idx): d["ans_idx"] = MetadataField(ans_idx) d["idx"] = MetadataField(ans_idx) # required by evaluate() if model_preprocessing_interface.model_flags["uses_pair_embedding"]: - inp = model_preprocessing_interface.boundary_token_fn( - para, s2=question + answer, trunc_strategy="trunc_s1", trunc_side="right" - ) + inp = model_preprocessing_interface.boundary_token_fn(para, question + answer) d["psg_qst_ans"] = sentence_to_text_field(inp, indexers) else: d["psg"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn( - passage, trunc_strategy="trunc_s1", trunc_side="right" - ), - indexers, + model_preprocessing_interface.boundary_token_fn(passage), indexers ) d["qst"] = sentence_to_text_field( model_preprocessing_interface.boundary_token_fn(question), indexers @@ -313,16 +308,11 @@ def _make_instance(psg, qst, ans_str, label, psg_idx, qst_idx, ans_idx): d["ans_idx"] = MetadataField(ans_idx) d["idx"] = MetadataField(ans_idx) # required by evaluate() if model_preprocessing_interface.model_flags["uses_pair_embedding"]: - inp = model_preprocessing_interface.boundary_token_fn( - psg, s2=qst, trunc_strategy="trunc_s1", trunc_side="right" - ) + inp = model_preprocessing_interface.boundary_token_fn(psg, qst) d["psg_qst_ans"] = sentence_to_text_field(inp, indexers) else: d["psg"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn( - psg, trunc_strategy="trunc_s1", trunc_side="right" - ), - indexers, + model_preprocessing_interface.boundary_token_fn(psg), indexers ) d["qst"] = sentence_to_text_field( model_preprocessing_interface.boundary_token_fn(qst), indexers @@ -468,19 +458,12 @@ def _make_instance(example): if model_preprocessing_interface.model_flags["uses_pair_embedding"]: inp, start_offset, _ = model_preprocessing_interface.boundary_token_fn( - example["passage"], - s2=example["question"], - trunc_strategy="trunc_s1", - trunc_side="right", - get_offset=True, + example["passage"], example["question"], get_offset=True ) d["inputs"] = sentence_to_text_field(inp, indexers) else: d["passage"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn( - example["passage"], trunc_strategy="trunc_s1", trunc_side="right" - ), - indexers, + model_preprocessing_interface.boundary_token_fn(example["passage"]), indexers ) d["question"] = sentence_to_text_field( model_preprocessing_interface.boundary_token_fn(example["question"]), indexers @@ -631,19 +614,12 @@ def _make_instance(example): if model_preprocessing_interface.model_flags["uses_pair_embedding"]: inp, start_offset, _ = model_preprocessing_interface.boundary_token_fn( - example["passage"], - s2=example["question"], - trunc_strategy="trunc_s1", - trunc_side="right", - get_offset=True, + example["passage"], example["question"], get_offset=True ) d["inputs"] = sentence_to_text_field(inp, indexers) else: d["passage"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn( - example["passage"], trunc_strategy="trunc_s1", trunc_side="right" - ), - indexers, + model_preprocessing_interface.boundary_token_fn(example["passage"]), indexers ) d["question"] = sentence_to_text_field( model_preprocessing_interface.boundary_token_fn(example["question"]), indexers @@ -909,16 +885,11 @@ def _make_instance(question, choices, label, id_str): d["question_str"] = MetadataField(" ".join(question)) if not model_preprocessing_interface.model_flags["uses_pair_embedding"]: d["question"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn( - question, trunc_strategy="trunc_s1", trunc_side="right" - ), - indexers, + model_preprocessing_interface.boundary_token_fn(question), indexers ) for choice_idx, choice in enumerate(choices): inp = ( - model_preprocessing_interface.boundary_token_fn( - question, s2=choice, trunc_strategy="trunc_s1", trunc_side="right" - ) + model_preprocessing_interface.boundary_token_fn(question, choice) if model_preprocessing_interface.model_flags["uses_pair_embedding"] else model_preprocessing_interface.boundary_token_fn(choice) ) @@ -1012,16 +983,11 @@ def _make_instance(context, choices, label, id_str): d["context_str"] = MetadataField(" ".join(context)) if not model_preprocessing_interface.model_flags["uses_pair_embedding"]: d["context"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn( - context, trunc_strategy="trunc_s1", trunc_side="right" - ), - indexers, + model_preprocessing_interface.boundary_token_fn(context), indexers ) for choice_idx, choice in enumerate(choices): inp = ( - model_preprocessing_interface.boundary_token_fn( - context, s2=choice, trunc_strategy="trunc_s1", trunc_side="right" - ) + model_preprocessing_interface.boundary_token_fn(context, choice) if model_preprocessing_interface.model_flags["uses_pair_embedding"] else model_preprocessing_interface.boundary_token_fn(choice) ) diff --git a/jiant/tasks/tasks.py b/jiant/tasks/tasks.py index 512a0b37d..a21a6cb34 100644 --- a/jiant/tasks/tasks.py +++ b/jiant/tasks/tasks.py @@ -108,25 +108,18 @@ def _make_instance(input1, input2, labels, idx): d = {} d["sent1_str"] = MetadataField(" ".join(input1)) if model_preprocessing_interface.model_flags["uses_pair_embedding"] and is_pair: - inp = model_preprocessing_interface.boundary_token_fn( - input1, s2=input2, trunc_strategy="trunc_s1", trunc_side="right" - ) + inp = model_preprocessing_interface.boundary_token_fn(input1, input2) d["inputs"] = sentence_to_text_field(inp, indexers) d["sent2_str"] = MetadataField(" ".join(input2)) if ( model_preprocessing_interface.model_flags["uses_mirrored_pair"] and is_symmetrical_pair ): - inp_m = model_preprocessing_interface.boundary_token_fn( - input1, s2=input2, trunc_strategy="trunc_s1", trunc_side="right" - ) + inp_m = model_preprocessing_interface.boundary_token_fn(input1, input2) d["inputs_m"] = sentence_to_text_field(inp_m, indexers) else: d["input1"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn( - input1, trunc_strategy="trunc_s1", trunc_side="right" - ), - indexers, + model_preprocessing_interface.boundary_token_fn(input1), indexers ) if input2: d["input2"] = sentence_to_text_field( @@ -795,10 +788,7 @@ def _make_instance(input1, labels, tagids): """ from multiple types in one column create multiple fields """ d = {} d["input1"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn( - input1, trunc_strategy="trunc_s1", trunc_side="right" - ), - indexers, + model_preprocessing_interface.boundary_token_fn(input1), indexers ) d["sent1_str"] = MetadataField(" ".join(input1)) d["labels"] = LabelField(labels, label_namespace="labels", skip_indexing=True) @@ -1631,16 +1621,11 @@ def _make_instance(input1, input2, label, idx, lex_sem, pr_ar_str, logic, knowle """ from multiple types in one column create multiple fields """ d = {} if model_preprocessing_interface.model_flags["uses_pair_embedding"]: - inp = model_preprocessing_interface.boundary_token_fn( - input1, s2=input2, trunc_strategy="trunc_s1", trunc_side="right" - ) + inp = model_preprocessing_interface.boundary_token_fn(input1, input2) d["inputs"] = sentence_to_text_field(inp, indexers) else: d["input1"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn( - input1, trunc_strategy="trunc_s1", trunc_side="right" - ), - indexers, + model_preprocessing_interface.boundary_token_fn(input1), indexers ) d["input2"] = sentence_to_text_field( model_preprocessing_interface.boundary_token_fn(input2), indexers @@ -1844,17 +1829,12 @@ def _make_instance(input1, input2, labels, idx, pair_id): d = {} d["sent1_str"] = MetadataField(" ".join(input1)) if model_preprocessing_interface.model_flags["uses_pair_embedding"]: - inp = model_preprocessing_interface.boundary_token_fn( - input1, s2=input2, trunc_strategy="trunc_s1", trunc_side="right" - ) + inp = model_preprocessing_interface.boundary_token_fn(input1, input2) d["inputs"] = sentence_to_text_field(inp, indexers) d["sent2_str"] = MetadataField(" ".join(input2)) else: d["input1"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn( - input1, trunc_strategy="trunc_s1", trunc_side="right" - ), - indexers, + model_preprocessing_interface.boundary_token_fn(input1), indexers ) if input2: d["input2"] = sentence_to_text_field( @@ -2211,10 +2191,7 @@ def process_split( def _make_instance(input1, input2, labels): d = {} d["input1"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn( - input1, trunc_strategy="trunc_s1", trunc_side="right" - ), - indexers, + model_preprocessing_interface.boundary_token_fn(input1), indexers ) d["input2"] = sentence_to_text_field( model_preprocessing_interface.boundary_token_fn(input2), indexers @@ -2314,15 +2291,12 @@ def _make_instance(input1, input2, labels): d = {} if model_preprocessing_interface.model_flags["uses_pair_embedding"]: inp = model_preprocessing_interface.boundary_token_fn( - input1, s2=input2, trunc_strategy="trunc_s1", trunc_side="right" + input1, input2 ) # drop leading [CLS] token d["inputs"] = sentence_to_text_field(inp, indexers) else: d["input1"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn( - input1, trunc_strategy="trunc_s1", trunc_side="right" - ), - indexers, + model_preprocessing_interface.boundary_token_fn(input1), indexers ) d["input2"] = sentence_to_text_field( model_preprocessing_interface.boundary_token_fn(input2), indexers @@ -2461,10 +2435,7 @@ def process_split( def _make_instance(input1, input2, target, mask): d = {} d["inputs"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn( - input1, trunc_strategy="trunc_s1", trunc_side="right" - ), - indexers, + model_preprocessing_interface.boundary_token_fn(input1), indexers ) d["sent1_str"] = MetadataField(" ".join(input1)) d["targs"] = sentence_to_text_field(target, self.target_indexer) @@ -2646,9 +2617,7 @@ def _make_span_field(self, s, text_field, offset=1): def make_instance(self, record, idx, indexers, model_preprocessing_interface) -> Type[Instance]: """Convert a single record to an AllenNLP Instance.""" tokens = record["text"].split() - tokens, offset = model_preprocessing_interface.boundary_token_fn( - tokens, trunc_strategy="trunc_s1", trunc_side="right", get_offset=True - ) + tokens, offset = model_preprocessing_interface.boundary_token_fn(tokens, get_offset=True) text_field = sentence_to_text_field(tokens, indexers) example = {} @@ -2866,16 +2835,12 @@ def _make_instance(input1, input2, idxs1, idxs2, labels, idx): d["sent2_str"] = MetadataField(" ".join(input2)) if model_preprocessing_interface.model_flags["uses_pair_embedding"]: inp, offset1, offset2 = model_preprocessing_interface.boundary_token_fn( - input1, - s2=input2, - trunc_strategy="trunc_s1", - trunc_side="right", - get_offset=True, + input1, input2, get_offset=True ) d["inputs"] = sentence_to_text_field(inp[: self.max_seq_len], indexers) else: inp1, offset1 = model_preprocessing_interface.boundary_token_fn( - input1, trunc_strategy="trunc_s1", trunc_side="right", get_offset=True + input1, get_offset=True ) inp2, offset2 = model_preprocessing_interface.boundary_token_fn( input2, get_offset=True @@ -3003,16 +2968,11 @@ def _make_instance(context, choices, question, label, idx): d["question_str"] = MetadataField(" ".join(context)) if not model_preprocessing_interface.model_flags["uses_pair_embedding"]: d["question"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn( - context, trunc_strategy="trunc_s1", trunc_side="right" - ), - indexers, + model_preprocessing_interface.boundary_token_fn(context), indexers ) for choice_idx, choice in enumerate(choices): inp = ( - model_preprocessing_interface.boundary_token_fn( - context, s2=question + choice, trunc_strategy="trunc_s1", trunc_side="right" - ) + model_preprocessing_interface.boundary_token_fn(context, question + choice) if model_preprocessing_interface.model_flags["uses_pair_embedding"] else model_preprocessing_interface.boundary_token_fn(choice) ) @@ -3171,16 +3131,11 @@ def _make_instance(context, choices, question, label, idx): d["question_str"] = MetadataField(" ".join(context)) if not model_preprocessing_interface.model_flags["uses_pair_embedding"]: d["question"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn( - context, trunc_strategy="trunc_s1", trunc_side="right" - ), - indexers, + model_preprocessing_interface.boundary_token_fn(context), indexers ) for choice_idx, choice in enumerate(choices): inp = ( - model_preprocessing_interface.boundary_token_fn( - context, s2=question + choice, trunc_strategy="trunc_s1", trunc_side="right" - ) + model_preprocessing_interface.boundary_token_fn(context, question + choice) if model_preprocessing_interface.model_flags["uses_pair_embedding"] else model_preprocessing_interface.boundary_token_fn(choice) ) @@ -3262,16 +3217,11 @@ def _make_instance(question, choices, label, idx): d["question_str"] = MetadataField(" ".join(question)) if not model_preprocessing_interface.model_flags["uses_pair_embedding"]: d["question"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn( - question, trunc_strategy="trunc_s1", trunc_side="right" - ), - indexers, + model_preprocessing_interface.boundary_token_fn(question), indexers ) for choice_idx, choice in enumerate(choices): inp = ( - model_preprocessing_interface.boundary_token_fn( - question, s2=choice, trunc_strategy="trunc_s1", trunc_side="right" - ) + model_preprocessing_interface.boundary_token_fn(question, choice) if model_preprocessing_interface.model_flags["uses_pair_embedding"] else model_preprocessing_interface.boundary_token_fn(choice) ) @@ -3358,16 +3308,11 @@ def _make_instance(question, choices, label, idx): d["question_str"] = MetadataField(" ".join(question)) if not model_preprocessing_interface.model_flags["uses_pair_embedding"]: d["question"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn( - question, trunc_strategy="trunc_s1", trunc_side="right" - ), - indexers, + model_preprocessing_interface.boundary_token_fn(question), indexers ) for choice_idx, choice in enumerate(choices): inp = ( - model_preprocessing_interface.boundary_token_fn( - question, s2=choice, trunc_strategy="trunc_s1", trunc_side="right" - ) + model_preprocessing_interface.boundary_token_fn(question, choice) if model_preprocessing_interface.model_flags["uses_pair_embedding"] else model_preprocessing_interface.boundary_token_fn(choice) ) @@ -3486,17 +3431,14 @@ def _make_instance(d, idx): new_d["passage_str"] = MetadataField(" ".join(d["passage"])) if not model_preprocessing_interface.model_flags["uses_pair_embedding"]: new_d["input1"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn( - d["passage"], trunc_strategy="trunc_s1", trunc_side="right" - ), - indexers, + model_preprocessing_interface.boundary_token_fn(d["passage"]), indexers ) new_d["input2"] = sentence_to_text_field( model_preprocessing_interface.boundary_token_fn(d["question"]), indexers ) else: # BERT/XLNet psg_qst = model_preprocessing_interface.boundary_token_fn( - d["passage"], s2=d["question"], trunc_strategy="trunc_s1", trunc_side="right" + d["passage"], d["question"] ) new_d["inputs"] = sentence_to_text_field(psg_qst, indexers) new_d["labels"] = LabelField(d["label"], label_namespace="labels", skip_indexing=True) @@ -3619,7 +3561,7 @@ def _make_instance(obs1, hyp1, hyp2, obs2, label, idx): else: for hyp_idx, hyp in enumerate([hyp1, hyp2]): inp = ( - model_preprocessing_interface.boundary_token_fn(obs1 + hyp, s2=obs2) + model_preprocessing_interface.boundary_token_fn(obs1 + hyp, obs2) if model_preprocessing_interface.model_flags["uses_pair_embedding"] else model_preprocessing_interface.boundary_token_fn(hyp) ) @@ -3767,16 +3709,11 @@ def _make_instance(context, choices, label, idx): d["question_str"] = MetadataField(" ".join(context)) if not model_preprocessing_interface.model_flags["uses_pair_embedding"]: d["question"] = sentence_to_text_field( - model_preprocessing_interface.boundary_token_fn( - context, trunc_strategy="trunc_s1", trunc_side="right" - ), - indexers, + model_preprocessing_interface.boundary_token_fn(context), indexers ) for choice_idx, choice in enumerate(choices): inp = ( - model_preprocessing_interface.boundary_token_fn( - context, s2=choice, trunc_strategy="trunc_s1", trunc_side="right" - ) + model_preprocessing_interface.boundary_token_fn(context, choice) if model_preprocessing_interface.model_flags["uses_pair_embedding"] else model_preprocessing_interface.boundary_token_fn(choice) ) diff --git a/jiant/utils/utils.py b/jiant/utils/utils.py index 34c87a428..3dca1174c 100644 --- a/jiant/utils/utils.py +++ b/jiant/utils/utils.py @@ -98,7 +98,7 @@ def select_pool_type(args): return pool_type -def apply_standard_boundary_tokens(s1, *, s2=None): +def apply_standard_boundary_tokens(s1, s2=None): """Apply and to sequences of string-valued tokens. Corresponds to more complex functions used with models like XLNet and BERT. """ diff --git a/tests/test_huggingface_transformers_interface.py b/tests/test_huggingface_transformers_interface.py index bb6eabdde..de5e0dc7e 100644 --- a/tests/test_huggingface_transformers_interface.py +++ b/tests/test_huggingface_transformers_interface.py @@ -23,7 +23,7 @@ def test_bert_apply_boundary_tokens(self): BertEmbedderModule.apply_boundary_tokens(s1), ["[CLS]", "A", "B", "C", "[SEP]"] ) self.assertListEqual( - BertEmbedderModule.apply_boundary_tokens(s1, s2=s2), + BertEmbedderModule.apply_boundary_tokens(s1, s2), ["[CLS]", "A", "B", "C", "[SEP]", "D", "E", "[SEP]"], ) @@ -34,7 +34,7 @@ def test_roberta_apply_boundary_tokens(self): RobertaEmbedderModule.apply_boundary_tokens(s1), ["", "A", "B", "C", ""] ) self.assertListEqual( - RobertaEmbedderModule.apply_boundary_tokens(s1, s2=s2), + RobertaEmbedderModule.apply_boundary_tokens(s1, s2), ["", "A", "B", "C", "", "", "D", "E", ""], ) @@ -45,7 +45,7 @@ def test_albert_apply_boundary_tokens(self): AlbertEmbedderModule.apply_boundary_tokens(s1), ["[CLS]", "A", "B", "C", "[SEP]"] ) self.assertListEqual( - AlbertEmbedderModule.apply_boundary_tokens(s1, s2=s2), + AlbertEmbedderModule.apply_boundary_tokens(s1, s2), ["[CLS]", "A", "B", "C", "[SEP]", "D", "E", "[SEP]"], ) @@ -56,7 +56,7 @@ def test_xlnet_apply_boundary_tokens(self): XLNetEmbedderModule.apply_boundary_tokens(s1), ["A", "B", "C", "", ""] ) self.assertListEqual( - XLNetEmbedderModule.apply_boundary_tokens(s1, s2=s2), + XLNetEmbedderModule.apply_boundary_tokens(s1, s2), ["A", "B", "C", "", "D", "E", "", ""], ) @@ -68,7 +68,7 @@ def test_gpt_apply_boundary_tokens(self): ["", "A", "B", "C", ""], ) self.assertListEqual( - OpenAIGPTEmbedderModule.apply_boundary_tokens(s1, s2=s2), + OpenAIGPTEmbedderModule.apply_boundary_tokens(s1, s2), ["", "A", "B", "C", "", "D", "E", ""], ) @@ -79,7 +79,7 @@ def test_xlm_apply_boundary_tokens(self): XLMEmbedderModule.apply_boundary_tokens(s1), ["", "A", "B", "C", ""] ) self.assertListEqual( - XLMEmbedderModule.apply_boundary_tokens(s1, s2=s2), + XLMEmbedderModule.apply_boundary_tokens(s1, s2), ["", "A", "B", "C", "", "D", "E", ""], ) diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index 2e2c6bda1..ef95a8031 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -8,14 +8,8 @@ from unittest import mock import jiant.tasks.tasks as tasks -from jiant.utils import config from jiant.utils.config import params_from_file -from jiant.preprocess import ( - get_task_without_loading_data, - build_indexers, - get_vocab, - ModelPreprocessingInterface, -) +from jiant.preprocess import get_task_without_loading_data, build_indexers, get_vocab class TestProprocess(unittest.TestCase): @@ -81,118 +75,3 @@ def test_build_vocab(self): assert set(vocab.get_index_to_token_vocabulary("chars").values()) == set( ["@@PADDING@@", "@@UNKNOWN@@", "a", "b", "c"] ) - - -class TestModelPreprocessingInterface(unittest.TestCase): - def test_max_tokens_limited_by_small_max_seq_len(self): - args = config.Params(max_seq_len=7, input_module="bert-base-uncased") - mpi = ModelPreprocessingInterface(args) - self.assertEqual(mpi.max_tokens, 7) - - def test_max_tokens_limited_by_bert_model_max(self): - args = config.Params(max_seq_len=None, input_module="bert-base-uncased") - mpi = ModelPreprocessingInterface(args) - self.assertEqual(mpi.max_tokens, 512) - - def test_boundary_token_fn_trunc_w_default_strategies(self): - MAX_SEQ_LEN = 7 - args = config.Params(max_seq_len=MAX_SEQ_LEN, input_module="bert-base-uncased") - mpi = ModelPreprocessingInterface(args) - seq = mpi.boundary_token_fn( - ["Apple", "buy", "call"], - s2=["Xray", "you", "zoo"], - trunc_strategy="trunc_s2", - trunc_side="right", - ) - self.assertEqual(len(seq), MAX_SEQ_LEN) - self.assertEqual(seq, ["[CLS]", "Apple", "buy", "call", "[SEP]", "Xray", "[SEP]"]) - - def test_boundary_token_fn_trunc_s2_left_side(self): - MAX_SEQ_LEN = 7 - args = config.Params(max_seq_len=MAX_SEQ_LEN, input_module="bert-base-uncased") - mpi = ModelPreprocessingInterface(args) - seq = mpi.boundary_token_fn( - ["Apple", "buy", "call"], - s2=["Xray", "you", "zoo"], - trunc_strategy="trunc_s2", - trunc_side="left", - ) - self.assertEqual(len(seq), MAX_SEQ_LEN) - self.assertEqual(seq, ["[CLS]", "Apple", "buy", "call", "[SEP]", "zoo", "[SEP]"]) - - def test_boundary_token_fn_trunc_s2_right_side(self): - MAX_SEQ_LEN = 7 - args = config.Params(max_seq_len=MAX_SEQ_LEN, input_module="bert-base-uncased") - mpi = ModelPreprocessingInterface(args) - seq = mpi.boundary_token_fn( - ["Apple", "buy", "call"], - s2=["Xray", "you", "zoo"], - trunc_strategy="trunc_s2", - trunc_side="right", - ) - self.assertEqual(len(seq), MAX_SEQ_LEN) - self.assertEqual(seq, ["[CLS]", "Apple", "buy", "call", "[SEP]", "Xray", "[SEP]"]) - - def test_boundary_token_fn_trunc_s1_left_side(self): - MAX_SEQ_LEN = 7 - args = config.Params(max_seq_len=MAX_SEQ_LEN, input_module="bert-base-uncased") - mpi = ModelPreprocessingInterface(args) - seq = mpi.boundary_token_fn( - ["Apple", "buy", "call"], - s2=["Xray", "you", "zoo"], - trunc_strategy="trunc_s1", - trunc_side="left", - ) - self.assertEqual(len(seq), MAX_SEQ_LEN) - self.assertEqual(seq, ["[CLS]", "call", "[SEP]", "Xray", "you", "zoo", "[SEP]"]) - - def test_boundary_token_fn_trunc_s1_right_side(self): - MAX_SEQ_LEN = 7 - args = config.Params(max_seq_len=MAX_SEQ_LEN, input_module="bert-base-uncased") - mpi = ModelPreprocessingInterface(args) - seq = mpi.boundary_token_fn( - ["Apple", "buy", "call"], - s2=["Xray", "you", "zoo"], - trunc_strategy="trunc_s1", - trunc_side="right", - ) - self.assertEqual(len(seq), MAX_SEQ_LEN) - self.assertEqual(seq, ["[CLS]", "Apple", "[SEP]", "Xray", "you", "zoo", "[SEP]"]) - - def test_boundary_token_fn_trunc_both(self): - MAX_SEQ_LEN = 7 - args = config.Params(max_seq_len=MAX_SEQ_LEN, input_module="bert-base-uncased") - mpi = ModelPreprocessingInterface(args) - seq = mpi.boundary_token_fn( - ["Apple", "buy", "call"], - s2=["Xray", "you", "zoo"], - trunc_strategy="trunc_both", - trunc_side="left", - ) - self.assertEqual(len(seq), MAX_SEQ_LEN) - self.assertEqual(seq, ["[CLS]", "buy", "call", "[SEP]", "you", "zoo", "[SEP]"]) - - def test_boundary_token_fn_trunc_with_short_sequence_and_no_max_seq_len(self): - args = config.Params(max_seq_len=None, input_module="bert-base-uncased") - mpi = ModelPreprocessingInterface(args) - seq = mpi.boundary_token_fn(["Apple", "buy", "call"], s2=["Xray", "you", "zoo"]) - self.assertEqual( - seq, ["[CLS]", "Apple", "buy", "call", "[SEP]", "Xray", "you", "zoo", "[SEP]"] - ) - - def test_boundary_token_fn_throws_exception_when_trunc_is_needed_but_strat_unspecified(self): - args = config.Params(max_seq_len=1, input_module="bert-base-uncased") - mpi = ModelPreprocessingInterface(args) - self.assertRaises(ValueError, mpi.boundary_token_fn, ["Apple", "buy", "call"], s2=["Xray"]) - - def test_boundary_token_fn_throws_exception_when_trunc_beyond_input_length(self): - args = config.Params(max_seq_len=1, input_module="bert-base-uncased") - mpi = ModelPreprocessingInterface(args) - self.assertRaises( - AssertionError, - mpi.boundary_token_fn, - ["Apple", "buy", "call"], - s2=["Xray"], - trunc_strategy="trunc_s2", - trunc_side="right", - )