diff --git a/jiant/huggingface_transformers_interface/modules.py b/jiant/huggingface_transformers_interface/modules.py
index d8957e0cd..8b10998e7 100644
--- a/jiant/huggingface_transformers_interface/modules.py
+++ b/jiant/huggingface_transformers_interface/modules.py
@@ -189,7 +189,7 @@ def get_seg_ids(self, token_ids, input_mask):
return seg_ids
@staticmethod
- def apply_boundary_tokens(s1, s2=None, get_offset=False):
+ def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
"""
A function that appliese the appropriate EOS/SOS/SEP/CLS tokens to token sequence or
token sequence pair for most tasks.
@@ -274,7 +274,7 @@ def __init__(self, args):
self.parameter_setup(args)
@staticmethod
- def apply_boundary_tokens(s1, s2=None, get_offset=False):
+ def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
# BERT-style boundary token padding on string token sequences
if s2:
s = ["[CLS]"] + s1 + ["[SEP]"] + s2 + ["[SEP]"]
@@ -331,7 +331,7 @@ def __init__(self, args):
self.parameter_setup(args)
@staticmethod
- def apply_boundary_tokens(s1, s2=None, get_offset=False):
+ def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
# RoBERTa-style boundary token padding on string token sequences
if s2:
s = [""] + s1 + ["", ""] + s2 + [""]
@@ -385,7 +385,7 @@ def __init__(self, args):
self.parameter_setup(args)
@staticmethod
- def apply_boundary_tokens(s1, s2=None, get_offset=False):
+ def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
# ALBERT-style boundary token padding on string token sequences
if s2:
s = ["[CLS]"] + s1 + ["[SEP]"] + s2 + ["[SEP]"]
@@ -448,7 +448,7 @@ def __init__(self, args):
self._SEG_ID_SEP = 3
@staticmethod
- def apply_boundary_tokens(s1, s2=None, get_offset=False):
+ def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
# XLNet-style boundary token marking on string token sequences
if s2:
s = s1 + [""] + s2 + ["", ""]
@@ -506,7 +506,7 @@ def __init__(self, args):
self.parameter_setup(args)
@staticmethod
- def apply_boundary_tokens(s1, s2=None, get_offset=False):
+ def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
# OpenAI-GPT-style boundary token marking on string token sequences
if s2:
s = [""] + s1 + [""] + s2 + [""]
@@ -568,7 +568,7 @@ def __init__(self, args):
self.parameter_setup(args)
@staticmethod
- def apply_boundary_tokens(s1, s2=None, get_offset=False):
+ def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
# GPT-2-style boundary token marking on string token sequences
if s2:
s = [""] + s1 + [""] + s2 + [""]
@@ -630,7 +630,7 @@ def __init__(self, args):
self.parameter_setup(args)
@staticmethod
- def apply_boundary_tokens(s1, s2=None, get_offset=False):
+ def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
# TransformerXL-style boundary token marking on string token sequences
if s2:
s = [""] + s1 + [""] + s2 + [""]
@@ -697,7 +697,7 @@ def __init__(self, args):
self.parameter_setup(args)
@staticmethod
- def apply_boundary_tokens(s1, s2=None, get_offset=False):
+ def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
# XLM-style boundary token marking on string token sequences
if s2:
s = [""] + s1 + [""] + s2 + [""]
diff --git a/jiant/preprocess.py b/jiant/preprocess.py
index 27e3b6e18..815c383d6 100644
--- a/jiant/preprocess.py
+++ b/jiant/preprocess.py
@@ -740,53 +740,202 @@ class ModelPreprocessingInterface(object):
"""
+ @staticmethod
+ def _apply_boundary_tokens(apply_boundary_tokens, max_tokens):
+ """
+ Takes a function that applies boundary tokens and modifies it to respect a max_seq_len arg.
+
+ Parameters
+ ----------
+ apply_boundary_tokens : function
+ Takes a function that adds boundary tokens and implements the common boundary token
+ adding function interface demonstated in HuggingfaceTransformersEmbedderModule.
+ max_tokens : int
+ The maximum number of tokens to allow in the output sequence.
+
+ Returns
+ -------
+ apply_boundary_tokens_with_trunc_strategy : function
+ Function w/ the common boundary token adding interface demonstrated in
+ HuggingfaceTransformersEmbedderModule, but also implementing a truncation strategy.
+ """
+
+ def _apply_boundary_tokens_with_trunc_strategy(
+ *args, trunc_strategy=None, trunc_side=None, **kwargs
+ ):
+ """
+ Calls the apply_boundary_tokens function provided in the parent function and, if the
+ output exceeds the max number of tokens, applies a truncation strategy to the inputs,
+ then re-applies the apply_boundary_tokens function on the truncated inputs and returns
+ the result.
+
+ Parameters
+ ----------
+ args : tuple
+ see args docs for apply_boundary_tokens and apply_lm_boundary_tokens in
+ huggingface_transformers_interface.modules
+ trunc_strategy : str
+ Which strings to truncate. Options:
+ `trunc_s1`: select s1 for truncation.
+ `trunc_s2`: select s2 for truncation.
+ `trunc_both`: truncate s1 and s2 equally.
+ trunc_side : str
+ Options are `right` or `left`. Indicates which side of the sequence to remove from.
+ e.g., if `left`, then the first element to be removed from ["A", "B", "C"] is "A".
+ kwargs : dict
+ see args docs for apply_boundary_tokens and apply_lm_boundary_tokens in
+ huggingface_transformers_interface.modules
+
+ Returns
+ -------
+ seq_w_boundry_tokens : List[str]
+ List of tokens returned by applying the specified apply_boundary_tokens function
+ to the inputs and using the specified truncation strategy.
+
+ Raises
+ ------
+ ValueError
+ If sequence requires truncation but trunc_strategy or trunc_side are unspecified
+ or invalid.
+
+ """
+ seq_w_boundry_tokens = apply_boundary_tokens(*args, **kwargs)
+ # if after calling the apply_boundary_tokens function the number of tokens is greater
+ # than the model's max, a truncation strategy can reduce the number of tokens:
+ num_excess_tokens = len(seq_w_boundry_tokens) - max_tokens
+ if num_excess_tokens > 0:
+ if not (trunc_strategy and trunc_side):
+ raise ValueError(
+ "Input(s) length will exceed model capacity or max_seq_len. "
+ + "Adjust settings as necessary, or, to automatically truncate "
+ + "inputs in `apply_boundary_tokens`, call `apply_boundary_tokens` "
+ + "with `trunc_strategy` and `trunc_side` keyword args."
+ )
+ if trunc_strategy == "trunc_s2":
+ s2 = kwargs["s2"]
+ if trunc_side == "right":
+ s2_truncated = s2[:-num_excess_tokens]
+ elif trunc_side == "left":
+ s2_truncated = s2[num_excess_tokens:]
+ log.info(
+ "Before truncation s2 length = "
+ + str(len(s2))
+ + ", after truncation s2 length = "
+ + str(len(s2_truncated))
+ )
+ assert len(s2_truncated) > 0, "After truncation, s2 length would be 0."
+ args = list(args)
+ kwargs["s2"] = s2_truncated
+ return apply_boundary_tokens(*args, **kwargs)
+ elif trunc_strategy == "trunc_both":
+ s1 = args[0]
+ s2 = kwargs["s2"]
+ if trunc_side == "right":
+ s1_truncated = s1[: -num_excess_tokens // 2]
+ s2_truncated = s2[: -num_excess_tokens // 2]
+ elif trunc_side == "left":
+ s1_truncated = s1[num_excess_tokens // 2 :]
+ s2_truncated = s2[num_excess_tokens // 2 :]
+ log.info(
+ "Before truncation s1 length = "
+ + str(len(s1))
+ + " and s2 length = "
+ + str(len(s2))
+ + ". After truncation s1 length = "
+ + str(len(s1_truncated))
+ + " and s2 length = "
+ + str(len(s2_truncated))
+ )
+ assert len(s1_truncated) > 0, "After truncation, s1 length would be 0."
+ assert len(s2_truncated) > 0, "After truncation, s2 length would be 0."
+ args = list(args)
+ args[0] = s1_truncated
+ kwargs["s2"] = s2_truncated
+ return apply_boundary_tokens(*args, **kwargs)
+ elif trunc_strategy == "trunc_s1":
+ s1 = args[0]
+ if trunc_side == "right":
+ s1_truncated = s1[:-num_excess_tokens]
+ elif trunc_side == "left":
+ s1_truncated = s1[num_excess_tokens:]
+ log.info(
+ "Before truncation s1 length = "
+ + str(len(s1))
+ + ", after truncation s1 length = "
+ + str(len(s1_truncated))
+ )
+ assert len(s1_truncated) > 0, "After truncation, s1 length would be 0."
+ args = list(args)
+ args[0] = s1_truncated
+ return apply_boundary_tokens(*args, **kwargs)
+ else:
+ raise ValueError(trunc_strategy + " is not a valid truncation strategy.")
+ else:
+ return seq_w_boundry_tokens
+
+ return _apply_boundary_tokens_with_trunc_strategy
+
def __init__(self, args):
boundary_token_fn = None
lm_boundary_token_fn = None
+ max_pos: int = None
if args.input_module.startswith("bert-"):
from jiant.huggingface_transformers_interface.modules import BertEmbedderModule
boundary_token_fn = BertEmbedderModule.apply_boundary_tokens
+ max_pos = BertTokenizer.max_model_input_sizes.get(args.input_module, None)
elif args.input_module.startswith("roberta-"):
from jiant.huggingface_transformers_interface.modules import RobertaEmbedderModule
boundary_token_fn = RobertaEmbedderModule.apply_boundary_tokens
+ max_pos = RobertaTokenizer.max_model_input_sizes.get(args.input_module, None)
elif args.input_module.startswith("albert-"):
from jiant.huggingface_transformers_interface.modules import AlbertEmbedderModule
boundary_token_fn = AlbertEmbedderModule.apply_boundary_tokens
+ max_pos = AlbertTokenizer.max_model_input_sizes.get(args.input_module, None)
elif args.input_module.startswith("xlnet-"):
from jiant.huggingface_transformers_interface.modules import XLNetEmbedderModule
boundary_token_fn = XLNetEmbedderModule.apply_boundary_tokens
+ max_pos = XLNetTokenizer.max_model_input_sizes.get(args.input_module, None)
elif args.input_module.startswith("openai-gpt"):
from jiant.huggingface_transformers_interface.modules import OpenAIGPTEmbedderModule
boundary_token_fn = OpenAIGPTEmbedderModule.apply_boundary_tokens
lm_boundary_token_fn = OpenAIGPTEmbedderModule.apply_lm_boundary_tokens
+ max_pos = OpenAIGPTTokenizer.max_model_input_sizes.get(args.input_module, None)
elif args.input_module.startswith("gpt2"):
from jiant.huggingface_transformers_interface.modules import GPT2EmbedderModule
boundary_token_fn = GPT2EmbedderModule.apply_boundary_tokens
lm_boundary_token_fn = GPT2EmbedderModule.apply_lm_boundary_tokens
+ max_pos = GPT2Tokenizer.max_model_input_sizes.get(args.input_module, None)
elif args.input_module.startswith("transfo-xl-"):
from jiant.huggingface_transformers_interface.modules import TransfoXLEmbedderModule
boundary_token_fn = TransfoXLEmbedderModule.apply_boundary_tokens
lm_boundary_token_fn = TransfoXLEmbedderModule.apply_lm_boundary_tokens
+ max_pos = TransfoXLTokenizer.max_model_input_sizes.get(args.input_module, None)
elif args.input_module.startswith("xlm-"):
from jiant.huggingface_transformers_interface.modules import XLMEmbedderModule
boundary_token_fn = XLMEmbedderModule.apply_boundary_tokens
+ max_pos = XLMTokenizer.max_model_input_sizes.get(args.input_module, None)
else:
boundary_token_fn = utils.apply_standard_boundary_tokens
- self.boundary_token_fn = boundary_token_fn
+ self.max_tokens = min(x for x in [max_pos, args.max_seq_len] if x is not None)
+
+ self.boundary_token_fn = self._apply_boundary_tokens(boundary_token_fn, self.max_tokens)
+
if lm_boundary_token_fn is not None:
- self.lm_boundary_token_fn = lm_boundary_token_fn
+ self.lm_boundary_token_fn = self._apply_boundary_tokens(
+ lm_boundary_token_fn, args.max_seq_len
+ )
else:
- self.lm_boundary_token_fn = boundary_token_fn
+ self.lm_boundary_token_fn = self.boundary_token_fn
from jiant.models import input_module_uses_pair_embedding, input_module_uses_mirrored_pair
diff --git a/jiant/tasks/edge_probing.py b/jiant/tasks/edge_probing.py
index 6821441ef..b6607382b 100644
--- a/jiant/tasks/edge_probing.py
+++ b/jiant/tasks/edge_probing.py
@@ -182,7 +182,7 @@ def make_instance(self, record, idx, indexers, model_preprocessing_interface) ->
"""Convert a single record to an AllenNLP Instance."""
tokens = record["text"].split() # already space-tokenized by Moses
tokens = model_preprocessing_interface.boundary_token_fn(
- tokens
+ tokens, trunc_strategy="trunc_s1", trunc_side="right"
) # apply model-appropriate variants of [cls] and [sep].
text_field = sentence_to_text_field(tokens, indexers)
diff --git a/jiant/tasks/qa.py b/jiant/tasks/qa.py
index 24da6d68f..7b2492908 100644
--- a/jiant/tasks/qa.py
+++ b/jiant/tasks/qa.py
@@ -120,11 +120,16 @@ def _make_instance(passage, question, answer, label, par_idx, qst_idx, ans_idx):
d["ans_idx"] = MetadataField(ans_idx)
d["idx"] = MetadataField(ans_idx) # required by evaluate()
if model_preprocessing_interface.model_flags["uses_pair_embedding"]:
- inp = model_preprocessing_interface.boundary_token_fn(para, question + answer)
+ inp = model_preprocessing_interface.boundary_token_fn(
+ para, s2=question + answer, trunc_strategy="trunc_s1", trunc_side="right"
+ )
d["psg_qst_ans"] = sentence_to_text_field(inp, indexers)
else:
d["psg"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(passage), indexers
+ model_preprocessing_interface.boundary_token_fn(
+ passage, trunc_strategy="trunc_s1", trunc_side="right"
+ ),
+ indexers,
)
d["qst"] = sentence_to_text_field(
model_preprocessing_interface.boundary_token_fn(question), indexers
@@ -308,11 +313,16 @@ def _make_instance(psg, qst, ans_str, label, psg_idx, qst_idx, ans_idx):
d["ans_idx"] = MetadataField(ans_idx)
d["idx"] = MetadataField(ans_idx) # required by evaluate()
if model_preprocessing_interface.model_flags["uses_pair_embedding"]:
- inp = model_preprocessing_interface.boundary_token_fn(psg, qst)
+ inp = model_preprocessing_interface.boundary_token_fn(
+ psg, s2=qst, trunc_strategy="trunc_s1", trunc_side="right"
+ )
d["psg_qst_ans"] = sentence_to_text_field(inp, indexers)
else:
d["psg"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(psg), indexers
+ model_preprocessing_interface.boundary_token_fn(
+ psg, trunc_strategy="trunc_s1", trunc_side="right"
+ ),
+ indexers,
)
d["qst"] = sentence_to_text_field(
model_preprocessing_interface.boundary_token_fn(qst), indexers
@@ -458,12 +468,19 @@ def _make_instance(example):
if model_preprocessing_interface.model_flags["uses_pair_embedding"]:
inp, start_offset, _ = model_preprocessing_interface.boundary_token_fn(
- example["passage"], example["question"], get_offset=True
+ example["passage"],
+ s2=example["question"],
+ trunc_strategy="trunc_s1",
+ trunc_side="right",
+ get_offset=True,
)
d["inputs"] = sentence_to_text_field(inp, indexers)
else:
d["passage"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(example["passage"]), indexers
+ model_preprocessing_interface.boundary_token_fn(
+ example["passage"], trunc_strategy="trunc_s1", trunc_side="right"
+ ),
+ indexers,
)
d["question"] = sentence_to_text_field(
model_preprocessing_interface.boundary_token_fn(example["question"]), indexers
@@ -614,12 +631,19 @@ def _make_instance(example):
if model_preprocessing_interface.model_flags["uses_pair_embedding"]:
inp, start_offset, _ = model_preprocessing_interface.boundary_token_fn(
- example["passage"], example["question"], get_offset=True
+ example["passage"],
+ s2=example["question"],
+ trunc_strategy="trunc_s1",
+ trunc_side="right",
+ get_offset=True,
)
d["inputs"] = sentence_to_text_field(inp, indexers)
else:
d["passage"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(example["passage"]), indexers
+ model_preprocessing_interface.boundary_token_fn(
+ example["passage"], trunc_strategy="trunc_s1", trunc_side="right"
+ ),
+ indexers,
)
d["question"] = sentence_to_text_field(
model_preprocessing_interface.boundary_token_fn(example["question"]), indexers
@@ -885,11 +909,16 @@ def _make_instance(question, choices, label, id_str):
d["question_str"] = MetadataField(" ".join(question))
if not model_preprocessing_interface.model_flags["uses_pair_embedding"]:
d["question"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(question), indexers
+ model_preprocessing_interface.boundary_token_fn(
+ question, trunc_strategy="trunc_s1", trunc_side="right"
+ ),
+ indexers,
)
for choice_idx, choice in enumerate(choices):
inp = (
- model_preprocessing_interface.boundary_token_fn(question, choice)
+ model_preprocessing_interface.boundary_token_fn(
+ question, s2=choice, trunc_strategy="trunc_s1", trunc_side="right"
+ )
if model_preprocessing_interface.model_flags["uses_pair_embedding"]
else model_preprocessing_interface.boundary_token_fn(choice)
)
@@ -983,11 +1012,16 @@ def _make_instance(context, choices, label, id_str):
d["context_str"] = MetadataField(" ".join(context))
if not model_preprocessing_interface.model_flags["uses_pair_embedding"]:
d["context"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(context), indexers
+ model_preprocessing_interface.boundary_token_fn(
+ context, trunc_strategy="trunc_s1", trunc_side="right"
+ ),
+ indexers,
)
for choice_idx, choice in enumerate(choices):
inp = (
- model_preprocessing_interface.boundary_token_fn(context, choice)
+ model_preprocessing_interface.boundary_token_fn(
+ context, s2=choice, trunc_strategy="trunc_s1", trunc_side="right"
+ )
if model_preprocessing_interface.model_flags["uses_pair_embedding"]
else model_preprocessing_interface.boundary_token_fn(choice)
)
diff --git a/jiant/tasks/tasks.py b/jiant/tasks/tasks.py
index a21a6cb34..512a0b37d 100644
--- a/jiant/tasks/tasks.py
+++ b/jiant/tasks/tasks.py
@@ -108,18 +108,25 @@ def _make_instance(input1, input2, labels, idx):
d = {}
d["sent1_str"] = MetadataField(" ".join(input1))
if model_preprocessing_interface.model_flags["uses_pair_embedding"] and is_pair:
- inp = model_preprocessing_interface.boundary_token_fn(input1, input2)
+ inp = model_preprocessing_interface.boundary_token_fn(
+ input1, s2=input2, trunc_strategy="trunc_s1", trunc_side="right"
+ )
d["inputs"] = sentence_to_text_field(inp, indexers)
d["sent2_str"] = MetadataField(" ".join(input2))
if (
model_preprocessing_interface.model_flags["uses_mirrored_pair"]
and is_symmetrical_pair
):
- inp_m = model_preprocessing_interface.boundary_token_fn(input1, input2)
+ inp_m = model_preprocessing_interface.boundary_token_fn(
+ input1, s2=input2, trunc_strategy="trunc_s1", trunc_side="right"
+ )
d["inputs_m"] = sentence_to_text_field(inp_m, indexers)
else:
d["input1"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(input1), indexers
+ model_preprocessing_interface.boundary_token_fn(
+ input1, trunc_strategy="trunc_s1", trunc_side="right"
+ ),
+ indexers,
)
if input2:
d["input2"] = sentence_to_text_field(
@@ -788,7 +795,10 @@ def _make_instance(input1, labels, tagids):
""" from multiple types in one column create multiple fields """
d = {}
d["input1"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(input1), indexers
+ model_preprocessing_interface.boundary_token_fn(
+ input1, trunc_strategy="trunc_s1", trunc_side="right"
+ ),
+ indexers,
)
d["sent1_str"] = MetadataField(" ".join(input1))
d["labels"] = LabelField(labels, label_namespace="labels", skip_indexing=True)
@@ -1621,11 +1631,16 @@ def _make_instance(input1, input2, label, idx, lex_sem, pr_ar_str, logic, knowle
""" from multiple types in one column create multiple fields """
d = {}
if model_preprocessing_interface.model_flags["uses_pair_embedding"]:
- inp = model_preprocessing_interface.boundary_token_fn(input1, input2)
+ inp = model_preprocessing_interface.boundary_token_fn(
+ input1, s2=input2, trunc_strategy="trunc_s1", trunc_side="right"
+ )
d["inputs"] = sentence_to_text_field(inp, indexers)
else:
d["input1"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(input1), indexers
+ model_preprocessing_interface.boundary_token_fn(
+ input1, trunc_strategy="trunc_s1", trunc_side="right"
+ ),
+ indexers,
)
d["input2"] = sentence_to_text_field(
model_preprocessing_interface.boundary_token_fn(input2), indexers
@@ -1829,12 +1844,17 @@ def _make_instance(input1, input2, labels, idx, pair_id):
d = {}
d["sent1_str"] = MetadataField(" ".join(input1))
if model_preprocessing_interface.model_flags["uses_pair_embedding"]:
- inp = model_preprocessing_interface.boundary_token_fn(input1, input2)
+ inp = model_preprocessing_interface.boundary_token_fn(
+ input1, s2=input2, trunc_strategy="trunc_s1", trunc_side="right"
+ )
d["inputs"] = sentence_to_text_field(inp, indexers)
d["sent2_str"] = MetadataField(" ".join(input2))
else:
d["input1"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(input1), indexers
+ model_preprocessing_interface.boundary_token_fn(
+ input1, trunc_strategy="trunc_s1", trunc_side="right"
+ ),
+ indexers,
)
if input2:
d["input2"] = sentence_to_text_field(
@@ -2191,7 +2211,10 @@ def process_split(
def _make_instance(input1, input2, labels):
d = {}
d["input1"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(input1), indexers
+ model_preprocessing_interface.boundary_token_fn(
+ input1, trunc_strategy="trunc_s1", trunc_side="right"
+ ),
+ indexers,
)
d["input2"] = sentence_to_text_field(
model_preprocessing_interface.boundary_token_fn(input2), indexers
@@ -2291,12 +2314,15 @@ def _make_instance(input1, input2, labels):
d = {}
if model_preprocessing_interface.model_flags["uses_pair_embedding"]:
inp = model_preprocessing_interface.boundary_token_fn(
- input1, input2
+ input1, s2=input2, trunc_strategy="trunc_s1", trunc_side="right"
) # drop leading [CLS] token
d["inputs"] = sentence_to_text_field(inp, indexers)
else:
d["input1"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(input1), indexers
+ model_preprocessing_interface.boundary_token_fn(
+ input1, trunc_strategy="trunc_s1", trunc_side="right"
+ ),
+ indexers,
)
d["input2"] = sentence_to_text_field(
model_preprocessing_interface.boundary_token_fn(input2), indexers
@@ -2435,7 +2461,10 @@ def process_split(
def _make_instance(input1, input2, target, mask):
d = {}
d["inputs"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(input1), indexers
+ model_preprocessing_interface.boundary_token_fn(
+ input1, trunc_strategy="trunc_s1", trunc_side="right"
+ ),
+ indexers,
)
d["sent1_str"] = MetadataField(" ".join(input1))
d["targs"] = sentence_to_text_field(target, self.target_indexer)
@@ -2617,7 +2646,9 @@ def _make_span_field(self, s, text_field, offset=1):
def make_instance(self, record, idx, indexers, model_preprocessing_interface) -> Type[Instance]:
"""Convert a single record to an AllenNLP Instance."""
tokens = record["text"].split()
- tokens, offset = model_preprocessing_interface.boundary_token_fn(tokens, get_offset=True)
+ tokens, offset = model_preprocessing_interface.boundary_token_fn(
+ tokens, trunc_strategy="trunc_s1", trunc_side="right", get_offset=True
+ )
text_field = sentence_to_text_field(tokens, indexers)
example = {}
@@ -2835,12 +2866,16 @@ def _make_instance(input1, input2, idxs1, idxs2, labels, idx):
d["sent2_str"] = MetadataField(" ".join(input2))
if model_preprocessing_interface.model_flags["uses_pair_embedding"]:
inp, offset1, offset2 = model_preprocessing_interface.boundary_token_fn(
- input1, input2, get_offset=True
+ input1,
+ s2=input2,
+ trunc_strategy="trunc_s1",
+ trunc_side="right",
+ get_offset=True,
)
d["inputs"] = sentence_to_text_field(inp[: self.max_seq_len], indexers)
else:
inp1, offset1 = model_preprocessing_interface.boundary_token_fn(
- input1, get_offset=True
+ input1, trunc_strategy="trunc_s1", trunc_side="right", get_offset=True
)
inp2, offset2 = model_preprocessing_interface.boundary_token_fn(
input2, get_offset=True
@@ -2968,11 +3003,16 @@ def _make_instance(context, choices, question, label, idx):
d["question_str"] = MetadataField(" ".join(context))
if not model_preprocessing_interface.model_flags["uses_pair_embedding"]:
d["question"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(context), indexers
+ model_preprocessing_interface.boundary_token_fn(
+ context, trunc_strategy="trunc_s1", trunc_side="right"
+ ),
+ indexers,
)
for choice_idx, choice in enumerate(choices):
inp = (
- model_preprocessing_interface.boundary_token_fn(context, question + choice)
+ model_preprocessing_interface.boundary_token_fn(
+ context, s2=question + choice, trunc_strategy="trunc_s1", trunc_side="right"
+ )
if model_preprocessing_interface.model_flags["uses_pair_embedding"]
else model_preprocessing_interface.boundary_token_fn(choice)
)
@@ -3131,11 +3171,16 @@ def _make_instance(context, choices, question, label, idx):
d["question_str"] = MetadataField(" ".join(context))
if not model_preprocessing_interface.model_flags["uses_pair_embedding"]:
d["question"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(context), indexers
+ model_preprocessing_interface.boundary_token_fn(
+ context, trunc_strategy="trunc_s1", trunc_side="right"
+ ),
+ indexers,
)
for choice_idx, choice in enumerate(choices):
inp = (
- model_preprocessing_interface.boundary_token_fn(context, question + choice)
+ model_preprocessing_interface.boundary_token_fn(
+ context, s2=question + choice, trunc_strategy="trunc_s1", trunc_side="right"
+ )
if model_preprocessing_interface.model_flags["uses_pair_embedding"]
else model_preprocessing_interface.boundary_token_fn(choice)
)
@@ -3217,11 +3262,16 @@ def _make_instance(question, choices, label, idx):
d["question_str"] = MetadataField(" ".join(question))
if not model_preprocessing_interface.model_flags["uses_pair_embedding"]:
d["question"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(question), indexers
+ model_preprocessing_interface.boundary_token_fn(
+ question, trunc_strategy="trunc_s1", trunc_side="right"
+ ),
+ indexers,
)
for choice_idx, choice in enumerate(choices):
inp = (
- model_preprocessing_interface.boundary_token_fn(question, choice)
+ model_preprocessing_interface.boundary_token_fn(
+ question, s2=choice, trunc_strategy="trunc_s1", trunc_side="right"
+ )
if model_preprocessing_interface.model_flags["uses_pair_embedding"]
else model_preprocessing_interface.boundary_token_fn(choice)
)
@@ -3308,11 +3358,16 @@ def _make_instance(question, choices, label, idx):
d["question_str"] = MetadataField(" ".join(question))
if not model_preprocessing_interface.model_flags["uses_pair_embedding"]:
d["question"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(question), indexers
+ model_preprocessing_interface.boundary_token_fn(
+ question, trunc_strategy="trunc_s1", trunc_side="right"
+ ),
+ indexers,
)
for choice_idx, choice in enumerate(choices):
inp = (
- model_preprocessing_interface.boundary_token_fn(question, choice)
+ model_preprocessing_interface.boundary_token_fn(
+ question, s2=choice, trunc_strategy="trunc_s1", trunc_side="right"
+ )
if model_preprocessing_interface.model_flags["uses_pair_embedding"]
else model_preprocessing_interface.boundary_token_fn(choice)
)
@@ -3431,14 +3486,17 @@ def _make_instance(d, idx):
new_d["passage_str"] = MetadataField(" ".join(d["passage"]))
if not model_preprocessing_interface.model_flags["uses_pair_embedding"]:
new_d["input1"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(d["passage"]), indexers
+ model_preprocessing_interface.boundary_token_fn(
+ d["passage"], trunc_strategy="trunc_s1", trunc_side="right"
+ ),
+ indexers,
)
new_d["input2"] = sentence_to_text_field(
model_preprocessing_interface.boundary_token_fn(d["question"]), indexers
)
else: # BERT/XLNet
psg_qst = model_preprocessing_interface.boundary_token_fn(
- d["passage"], d["question"]
+ d["passage"], s2=d["question"], trunc_strategy="trunc_s1", trunc_side="right"
)
new_d["inputs"] = sentence_to_text_field(psg_qst, indexers)
new_d["labels"] = LabelField(d["label"], label_namespace="labels", skip_indexing=True)
@@ -3561,7 +3619,7 @@ def _make_instance(obs1, hyp1, hyp2, obs2, label, idx):
else:
for hyp_idx, hyp in enumerate([hyp1, hyp2]):
inp = (
- model_preprocessing_interface.boundary_token_fn(obs1 + hyp, obs2)
+ model_preprocessing_interface.boundary_token_fn(obs1 + hyp, s2=obs2)
if model_preprocessing_interface.model_flags["uses_pair_embedding"]
else model_preprocessing_interface.boundary_token_fn(hyp)
)
@@ -3709,11 +3767,16 @@ def _make_instance(context, choices, label, idx):
d["question_str"] = MetadataField(" ".join(context))
if not model_preprocessing_interface.model_flags["uses_pair_embedding"]:
d["question"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(context), indexers
+ model_preprocessing_interface.boundary_token_fn(
+ context, trunc_strategy="trunc_s1", trunc_side="right"
+ ),
+ indexers,
)
for choice_idx, choice in enumerate(choices):
inp = (
- model_preprocessing_interface.boundary_token_fn(context, choice)
+ model_preprocessing_interface.boundary_token_fn(
+ context, s2=choice, trunc_strategy="trunc_s1", trunc_side="right"
+ )
if model_preprocessing_interface.model_flags["uses_pair_embedding"]
else model_preprocessing_interface.boundary_token_fn(choice)
)
diff --git a/jiant/utils/utils.py b/jiant/utils/utils.py
index 3dca1174c..34c87a428 100644
--- a/jiant/utils/utils.py
+++ b/jiant/utils/utils.py
@@ -98,7 +98,7 @@ def select_pool_type(args):
return pool_type
-def apply_standard_boundary_tokens(s1, s2=None):
+def apply_standard_boundary_tokens(s1, *, s2=None):
"""Apply and to sequences of string-valued tokens.
Corresponds to more complex functions used with models like XLNet and BERT.
"""
diff --git a/tests/test_huggingface_transformers_interface.py b/tests/test_huggingface_transformers_interface.py
index de5e0dc7e..bb6eabdde 100644
--- a/tests/test_huggingface_transformers_interface.py
+++ b/tests/test_huggingface_transformers_interface.py
@@ -23,7 +23,7 @@ def test_bert_apply_boundary_tokens(self):
BertEmbedderModule.apply_boundary_tokens(s1), ["[CLS]", "A", "B", "C", "[SEP]"]
)
self.assertListEqual(
- BertEmbedderModule.apply_boundary_tokens(s1, s2),
+ BertEmbedderModule.apply_boundary_tokens(s1, s2=s2),
["[CLS]", "A", "B", "C", "[SEP]", "D", "E", "[SEP]"],
)
@@ -34,7 +34,7 @@ def test_roberta_apply_boundary_tokens(self):
RobertaEmbedderModule.apply_boundary_tokens(s1), ["", "A", "B", "C", ""]
)
self.assertListEqual(
- RobertaEmbedderModule.apply_boundary_tokens(s1, s2),
+ RobertaEmbedderModule.apply_boundary_tokens(s1, s2=s2),
["", "A", "B", "C", "", "", "D", "E", ""],
)
@@ -45,7 +45,7 @@ def test_albert_apply_boundary_tokens(self):
AlbertEmbedderModule.apply_boundary_tokens(s1), ["[CLS]", "A", "B", "C", "[SEP]"]
)
self.assertListEqual(
- AlbertEmbedderModule.apply_boundary_tokens(s1, s2),
+ AlbertEmbedderModule.apply_boundary_tokens(s1, s2=s2),
["[CLS]", "A", "B", "C", "[SEP]", "D", "E", "[SEP]"],
)
@@ -56,7 +56,7 @@ def test_xlnet_apply_boundary_tokens(self):
XLNetEmbedderModule.apply_boundary_tokens(s1), ["A", "B", "C", "", ""]
)
self.assertListEqual(
- XLNetEmbedderModule.apply_boundary_tokens(s1, s2),
+ XLNetEmbedderModule.apply_boundary_tokens(s1, s2=s2),
["A", "B", "C", "", "D", "E", "", ""],
)
@@ -68,7 +68,7 @@ def test_gpt_apply_boundary_tokens(self):
["", "A", "B", "C", ""],
)
self.assertListEqual(
- OpenAIGPTEmbedderModule.apply_boundary_tokens(s1, s2),
+ OpenAIGPTEmbedderModule.apply_boundary_tokens(s1, s2=s2),
["", "A", "B", "C", "", "D", "E", ""],
)
@@ -79,7 +79,7 @@ def test_xlm_apply_boundary_tokens(self):
XLMEmbedderModule.apply_boundary_tokens(s1), ["", "A", "B", "C", ""]
)
self.assertListEqual(
- XLMEmbedderModule.apply_boundary_tokens(s1, s2),
+ XLMEmbedderModule.apply_boundary_tokens(s1, s2=s2),
["", "A", "B", "C", "", "D", "E", ""],
)
diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py
index ef95a8031..2e2c6bda1 100644
--- a/tests/test_preprocess.py
+++ b/tests/test_preprocess.py
@@ -8,8 +8,14 @@
from unittest import mock
import jiant.tasks.tasks as tasks
+from jiant.utils import config
from jiant.utils.config import params_from_file
-from jiant.preprocess import get_task_without_loading_data, build_indexers, get_vocab
+from jiant.preprocess import (
+ get_task_without_loading_data,
+ build_indexers,
+ get_vocab,
+ ModelPreprocessingInterface,
+)
class TestProprocess(unittest.TestCase):
@@ -75,3 +81,118 @@ def test_build_vocab(self):
assert set(vocab.get_index_to_token_vocabulary("chars").values()) == set(
["@@PADDING@@", "@@UNKNOWN@@", "a", "b", "c"]
)
+
+
+class TestModelPreprocessingInterface(unittest.TestCase):
+ def test_max_tokens_limited_by_small_max_seq_len(self):
+ args = config.Params(max_seq_len=7, input_module="bert-base-uncased")
+ mpi = ModelPreprocessingInterface(args)
+ self.assertEqual(mpi.max_tokens, 7)
+
+ def test_max_tokens_limited_by_bert_model_max(self):
+ args = config.Params(max_seq_len=None, input_module="bert-base-uncased")
+ mpi = ModelPreprocessingInterface(args)
+ self.assertEqual(mpi.max_tokens, 512)
+
+ def test_boundary_token_fn_trunc_w_default_strategies(self):
+ MAX_SEQ_LEN = 7
+ args = config.Params(max_seq_len=MAX_SEQ_LEN, input_module="bert-base-uncased")
+ mpi = ModelPreprocessingInterface(args)
+ seq = mpi.boundary_token_fn(
+ ["Apple", "buy", "call"],
+ s2=["Xray", "you", "zoo"],
+ trunc_strategy="trunc_s2",
+ trunc_side="right",
+ )
+ self.assertEqual(len(seq), MAX_SEQ_LEN)
+ self.assertEqual(seq, ["[CLS]", "Apple", "buy", "call", "[SEP]", "Xray", "[SEP]"])
+
+ def test_boundary_token_fn_trunc_s2_left_side(self):
+ MAX_SEQ_LEN = 7
+ args = config.Params(max_seq_len=MAX_SEQ_LEN, input_module="bert-base-uncased")
+ mpi = ModelPreprocessingInterface(args)
+ seq = mpi.boundary_token_fn(
+ ["Apple", "buy", "call"],
+ s2=["Xray", "you", "zoo"],
+ trunc_strategy="trunc_s2",
+ trunc_side="left",
+ )
+ self.assertEqual(len(seq), MAX_SEQ_LEN)
+ self.assertEqual(seq, ["[CLS]", "Apple", "buy", "call", "[SEP]", "zoo", "[SEP]"])
+
+ def test_boundary_token_fn_trunc_s2_right_side(self):
+ MAX_SEQ_LEN = 7
+ args = config.Params(max_seq_len=MAX_SEQ_LEN, input_module="bert-base-uncased")
+ mpi = ModelPreprocessingInterface(args)
+ seq = mpi.boundary_token_fn(
+ ["Apple", "buy", "call"],
+ s2=["Xray", "you", "zoo"],
+ trunc_strategy="trunc_s2",
+ trunc_side="right",
+ )
+ self.assertEqual(len(seq), MAX_SEQ_LEN)
+ self.assertEqual(seq, ["[CLS]", "Apple", "buy", "call", "[SEP]", "Xray", "[SEP]"])
+
+ def test_boundary_token_fn_trunc_s1_left_side(self):
+ MAX_SEQ_LEN = 7
+ args = config.Params(max_seq_len=MAX_SEQ_LEN, input_module="bert-base-uncased")
+ mpi = ModelPreprocessingInterface(args)
+ seq = mpi.boundary_token_fn(
+ ["Apple", "buy", "call"],
+ s2=["Xray", "you", "zoo"],
+ trunc_strategy="trunc_s1",
+ trunc_side="left",
+ )
+ self.assertEqual(len(seq), MAX_SEQ_LEN)
+ self.assertEqual(seq, ["[CLS]", "call", "[SEP]", "Xray", "you", "zoo", "[SEP]"])
+
+ def test_boundary_token_fn_trunc_s1_right_side(self):
+ MAX_SEQ_LEN = 7
+ args = config.Params(max_seq_len=MAX_SEQ_LEN, input_module="bert-base-uncased")
+ mpi = ModelPreprocessingInterface(args)
+ seq = mpi.boundary_token_fn(
+ ["Apple", "buy", "call"],
+ s2=["Xray", "you", "zoo"],
+ trunc_strategy="trunc_s1",
+ trunc_side="right",
+ )
+ self.assertEqual(len(seq), MAX_SEQ_LEN)
+ self.assertEqual(seq, ["[CLS]", "Apple", "[SEP]", "Xray", "you", "zoo", "[SEP]"])
+
+ def test_boundary_token_fn_trunc_both(self):
+ MAX_SEQ_LEN = 7
+ args = config.Params(max_seq_len=MAX_SEQ_LEN, input_module="bert-base-uncased")
+ mpi = ModelPreprocessingInterface(args)
+ seq = mpi.boundary_token_fn(
+ ["Apple", "buy", "call"],
+ s2=["Xray", "you", "zoo"],
+ trunc_strategy="trunc_both",
+ trunc_side="left",
+ )
+ self.assertEqual(len(seq), MAX_SEQ_LEN)
+ self.assertEqual(seq, ["[CLS]", "buy", "call", "[SEP]", "you", "zoo", "[SEP]"])
+
+ def test_boundary_token_fn_trunc_with_short_sequence_and_no_max_seq_len(self):
+ args = config.Params(max_seq_len=None, input_module="bert-base-uncased")
+ mpi = ModelPreprocessingInterface(args)
+ seq = mpi.boundary_token_fn(["Apple", "buy", "call"], s2=["Xray", "you", "zoo"])
+ self.assertEqual(
+ seq, ["[CLS]", "Apple", "buy", "call", "[SEP]", "Xray", "you", "zoo", "[SEP]"]
+ )
+
+ def test_boundary_token_fn_throws_exception_when_trunc_is_needed_but_strat_unspecified(self):
+ args = config.Params(max_seq_len=1, input_module="bert-base-uncased")
+ mpi = ModelPreprocessingInterface(args)
+ self.assertRaises(ValueError, mpi.boundary_token_fn, ["Apple", "buy", "call"], s2=["Xray"])
+
+ def test_boundary_token_fn_throws_exception_when_trunc_beyond_input_length(self):
+ args = config.Params(max_seq_len=1, input_module="bert-base-uncased")
+ mpi = ModelPreprocessingInterface(args)
+ self.assertRaises(
+ AssertionError,
+ mpi.boundary_token_fn,
+ ["Apple", "buy", "call"],
+ s2=["Xray"],
+ trunc_strategy="trunc_s2",
+ trunc_side="right",
+ )