diff --git a/jiant/huggingface_transformers_interface/modules.py b/jiant/huggingface_transformers_interface/modules.py
index 8b10998e7..d8957e0cd 100644
--- a/jiant/huggingface_transformers_interface/modules.py
+++ b/jiant/huggingface_transformers_interface/modules.py
@@ -189,7 +189,7 @@ def get_seg_ids(self, token_ids, input_mask):
return seg_ids
@staticmethod
- def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
+ def apply_boundary_tokens(s1, s2=None, get_offset=False):
"""
A function that appliese the appropriate EOS/SOS/SEP/CLS tokens to token sequence or
token sequence pair for most tasks.
@@ -274,7 +274,7 @@ def __init__(self, args):
self.parameter_setup(args)
@staticmethod
- def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
+ def apply_boundary_tokens(s1, s2=None, get_offset=False):
# BERT-style boundary token padding on string token sequences
if s2:
s = ["[CLS]"] + s1 + ["[SEP]"] + s2 + ["[SEP]"]
@@ -331,7 +331,7 @@ def __init__(self, args):
self.parameter_setup(args)
@staticmethod
- def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
+ def apply_boundary_tokens(s1, s2=None, get_offset=False):
# RoBERTa-style boundary token padding on string token sequences
if s2:
s = [""] + s1 + ["", ""] + s2 + [""]
@@ -385,7 +385,7 @@ def __init__(self, args):
self.parameter_setup(args)
@staticmethod
- def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
+ def apply_boundary_tokens(s1, s2=None, get_offset=False):
# ALBERT-style boundary token padding on string token sequences
if s2:
s = ["[CLS]"] + s1 + ["[SEP]"] + s2 + ["[SEP]"]
@@ -448,7 +448,7 @@ def __init__(self, args):
self._SEG_ID_SEP = 3
@staticmethod
- def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
+ def apply_boundary_tokens(s1, s2=None, get_offset=False):
# XLNet-style boundary token marking on string token sequences
if s2:
s = s1 + [""] + s2 + ["", ""]
@@ -506,7 +506,7 @@ def __init__(self, args):
self.parameter_setup(args)
@staticmethod
- def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
+ def apply_boundary_tokens(s1, s2=None, get_offset=False):
# OpenAI-GPT-style boundary token marking on string token sequences
if s2:
s = [""] + s1 + [""] + s2 + [""]
@@ -568,7 +568,7 @@ def __init__(self, args):
self.parameter_setup(args)
@staticmethod
- def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
+ def apply_boundary_tokens(s1, s2=None, get_offset=False):
# GPT-2-style boundary token marking on string token sequences
if s2:
s = [""] + s1 + [""] + s2 + [""]
@@ -630,7 +630,7 @@ def __init__(self, args):
self.parameter_setup(args)
@staticmethod
- def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
+ def apply_boundary_tokens(s1, s2=None, get_offset=False):
# TransformerXL-style boundary token marking on string token sequences
if s2:
s = [""] + s1 + [""] + s2 + [""]
@@ -697,7 +697,7 @@ def __init__(self, args):
self.parameter_setup(args)
@staticmethod
- def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
+ def apply_boundary_tokens(s1, s2=None, get_offset=False):
# XLM-style boundary token marking on string token sequences
if s2:
s = [""] + s1 + [""] + s2 + [""]
diff --git a/jiant/preprocess.py b/jiant/preprocess.py
index 815c383d6..27e3b6e18 100644
--- a/jiant/preprocess.py
+++ b/jiant/preprocess.py
@@ -740,202 +740,53 @@ class ModelPreprocessingInterface(object):
"""
- @staticmethod
- def _apply_boundary_tokens(apply_boundary_tokens, max_tokens):
- """
- Takes a function that applies boundary tokens and modifies it to respect a max_seq_len arg.
-
- Parameters
- ----------
- apply_boundary_tokens : function
- Takes a function that adds boundary tokens and implements the common boundary token
- adding function interface demonstated in HuggingfaceTransformersEmbedderModule.
- max_tokens : int
- The maximum number of tokens to allow in the output sequence.
-
- Returns
- -------
- apply_boundary_tokens_with_trunc_strategy : function
- Function w/ the common boundary token adding interface demonstrated in
- HuggingfaceTransformersEmbedderModule, but also implementing a truncation strategy.
- """
-
- def _apply_boundary_tokens_with_trunc_strategy(
- *args, trunc_strategy=None, trunc_side=None, **kwargs
- ):
- """
- Calls the apply_boundary_tokens function provided in the parent function and, if the
- output exceeds the max number of tokens, applies a truncation strategy to the inputs,
- then re-applies the apply_boundary_tokens function on the truncated inputs and returns
- the result.
-
- Parameters
- ----------
- args : tuple
- see args docs for apply_boundary_tokens and apply_lm_boundary_tokens in
- huggingface_transformers_interface.modules
- trunc_strategy : str
- Which strings to truncate. Options:
- `trunc_s1`: select s1 for truncation.
- `trunc_s2`: select s2 for truncation.
- `trunc_both`: truncate s1 and s2 equally.
- trunc_side : str
- Options are `right` or `left`. Indicates which side of the sequence to remove from.
- e.g., if `left`, then the first element to be removed from ["A", "B", "C"] is "A".
- kwargs : dict
- see args docs for apply_boundary_tokens and apply_lm_boundary_tokens in
- huggingface_transformers_interface.modules
-
- Returns
- -------
- seq_w_boundry_tokens : List[str]
- List of tokens returned by applying the specified apply_boundary_tokens function
- to the inputs and using the specified truncation strategy.
-
- Raises
- ------
- ValueError
- If sequence requires truncation but trunc_strategy or trunc_side are unspecified
- or invalid.
-
- """
- seq_w_boundry_tokens = apply_boundary_tokens(*args, **kwargs)
- # if after calling the apply_boundary_tokens function the number of tokens is greater
- # than the model's max, a truncation strategy can reduce the number of tokens:
- num_excess_tokens = len(seq_w_boundry_tokens) - max_tokens
- if num_excess_tokens > 0:
- if not (trunc_strategy and trunc_side):
- raise ValueError(
- "Input(s) length will exceed model capacity or max_seq_len. "
- + "Adjust settings as necessary, or, to automatically truncate "
- + "inputs in `apply_boundary_tokens`, call `apply_boundary_tokens` "
- + "with `trunc_strategy` and `trunc_side` keyword args."
- )
- if trunc_strategy == "trunc_s2":
- s2 = kwargs["s2"]
- if trunc_side == "right":
- s2_truncated = s2[:-num_excess_tokens]
- elif trunc_side == "left":
- s2_truncated = s2[num_excess_tokens:]
- log.info(
- "Before truncation s2 length = "
- + str(len(s2))
- + ", after truncation s2 length = "
- + str(len(s2_truncated))
- )
- assert len(s2_truncated) > 0, "After truncation, s2 length would be 0."
- args = list(args)
- kwargs["s2"] = s2_truncated
- return apply_boundary_tokens(*args, **kwargs)
- elif trunc_strategy == "trunc_both":
- s1 = args[0]
- s2 = kwargs["s2"]
- if trunc_side == "right":
- s1_truncated = s1[: -num_excess_tokens // 2]
- s2_truncated = s2[: -num_excess_tokens // 2]
- elif trunc_side == "left":
- s1_truncated = s1[num_excess_tokens // 2 :]
- s2_truncated = s2[num_excess_tokens // 2 :]
- log.info(
- "Before truncation s1 length = "
- + str(len(s1))
- + " and s2 length = "
- + str(len(s2))
- + ". After truncation s1 length = "
- + str(len(s1_truncated))
- + " and s2 length = "
- + str(len(s2_truncated))
- )
- assert len(s1_truncated) > 0, "After truncation, s1 length would be 0."
- assert len(s2_truncated) > 0, "After truncation, s2 length would be 0."
- args = list(args)
- args[0] = s1_truncated
- kwargs["s2"] = s2_truncated
- return apply_boundary_tokens(*args, **kwargs)
- elif trunc_strategy == "trunc_s1":
- s1 = args[0]
- if trunc_side == "right":
- s1_truncated = s1[:-num_excess_tokens]
- elif trunc_side == "left":
- s1_truncated = s1[num_excess_tokens:]
- log.info(
- "Before truncation s1 length = "
- + str(len(s1))
- + ", after truncation s1 length = "
- + str(len(s1_truncated))
- )
- assert len(s1_truncated) > 0, "After truncation, s1 length would be 0."
- args = list(args)
- args[0] = s1_truncated
- return apply_boundary_tokens(*args, **kwargs)
- else:
- raise ValueError(trunc_strategy + " is not a valid truncation strategy.")
- else:
- return seq_w_boundry_tokens
-
- return _apply_boundary_tokens_with_trunc_strategy
-
def __init__(self, args):
boundary_token_fn = None
lm_boundary_token_fn = None
- max_pos: int = None
if args.input_module.startswith("bert-"):
from jiant.huggingface_transformers_interface.modules import BertEmbedderModule
boundary_token_fn = BertEmbedderModule.apply_boundary_tokens
- max_pos = BertTokenizer.max_model_input_sizes.get(args.input_module, None)
elif args.input_module.startswith("roberta-"):
from jiant.huggingface_transformers_interface.modules import RobertaEmbedderModule
boundary_token_fn = RobertaEmbedderModule.apply_boundary_tokens
- max_pos = RobertaTokenizer.max_model_input_sizes.get(args.input_module, None)
elif args.input_module.startswith("albert-"):
from jiant.huggingface_transformers_interface.modules import AlbertEmbedderModule
boundary_token_fn = AlbertEmbedderModule.apply_boundary_tokens
- max_pos = AlbertTokenizer.max_model_input_sizes.get(args.input_module, None)
elif args.input_module.startswith("xlnet-"):
from jiant.huggingface_transformers_interface.modules import XLNetEmbedderModule
boundary_token_fn = XLNetEmbedderModule.apply_boundary_tokens
- max_pos = XLNetTokenizer.max_model_input_sizes.get(args.input_module, None)
elif args.input_module.startswith("openai-gpt"):
from jiant.huggingface_transformers_interface.modules import OpenAIGPTEmbedderModule
boundary_token_fn = OpenAIGPTEmbedderModule.apply_boundary_tokens
lm_boundary_token_fn = OpenAIGPTEmbedderModule.apply_lm_boundary_tokens
- max_pos = OpenAIGPTTokenizer.max_model_input_sizes.get(args.input_module, None)
elif args.input_module.startswith("gpt2"):
from jiant.huggingface_transformers_interface.modules import GPT2EmbedderModule
boundary_token_fn = GPT2EmbedderModule.apply_boundary_tokens
lm_boundary_token_fn = GPT2EmbedderModule.apply_lm_boundary_tokens
- max_pos = GPT2Tokenizer.max_model_input_sizes.get(args.input_module, None)
elif args.input_module.startswith("transfo-xl-"):
from jiant.huggingface_transformers_interface.modules import TransfoXLEmbedderModule
boundary_token_fn = TransfoXLEmbedderModule.apply_boundary_tokens
lm_boundary_token_fn = TransfoXLEmbedderModule.apply_lm_boundary_tokens
- max_pos = TransfoXLTokenizer.max_model_input_sizes.get(args.input_module, None)
elif args.input_module.startswith("xlm-"):
from jiant.huggingface_transformers_interface.modules import XLMEmbedderModule
boundary_token_fn = XLMEmbedderModule.apply_boundary_tokens
- max_pos = XLMTokenizer.max_model_input_sizes.get(args.input_module, None)
else:
boundary_token_fn = utils.apply_standard_boundary_tokens
- self.max_tokens = min(x for x in [max_pos, args.max_seq_len] if x is not None)
-
- self.boundary_token_fn = self._apply_boundary_tokens(boundary_token_fn, self.max_tokens)
-
+ self.boundary_token_fn = boundary_token_fn
if lm_boundary_token_fn is not None:
- self.lm_boundary_token_fn = self._apply_boundary_tokens(
- lm_boundary_token_fn, args.max_seq_len
- )
+ self.lm_boundary_token_fn = lm_boundary_token_fn
else:
- self.lm_boundary_token_fn = self.boundary_token_fn
+ self.lm_boundary_token_fn = boundary_token_fn
from jiant.models import input_module_uses_pair_embedding, input_module_uses_mirrored_pair
diff --git a/jiant/tasks/edge_probing.py b/jiant/tasks/edge_probing.py
index b6607382b..6821441ef 100644
--- a/jiant/tasks/edge_probing.py
+++ b/jiant/tasks/edge_probing.py
@@ -182,7 +182,7 @@ def make_instance(self, record, idx, indexers, model_preprocessing_interface) ->
"""Convert a single record to an AllenNLP Instance."""
tokens = record["text"].split() # already space-tokenized by Moses
tokens = model_preprocessing_interface.boundary_token_fn(
- tokens, trunc_strategy="trunc_s1", trunc_side="right"
+ tokens
) # apply model-appropriate variants of [cls] and [sep].
text_field = sentence_to_text_field(tokens, indexers)
diff --git a/jiant/tasks/qa.py b/jiant/tasks/qa.py
index 7b2492908..24da6d68f 100644
--- a/jiant/tasks/qa.py
+++ b/jiant/tasks/qa.py
@@ -120,16 +120,11 @@ def _make_instance(passage, question, answer, label, par_idx, qst_idx, ans_idx):
d["ans_idx"] = MetadataField(ans_idx)
d["idx"] = MetadataField(ans_idx) # required by evaluate()
if model_preprocessing_interface.model_flags["uses_pair_embedding"]:
- inp = model_preprocessing_interface.boundary_token_fn(
- para, s2=question + answer, trunc_strategy="trunc_s1", trunc_side="right"
- )
+ inp = model_preprocessing_interface.boundary_token_fn(para, question + answer)
d["psg_qst_ans"] = sentence_to_text_field(inp, indexers)
else:
d["psg"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(
- passage, trunc_strategy="trunc_s1", trunc_side="right"
- ),
- indexers,
+ model_preprocessing_interface.boundary_token_fn(passage), indexers
)
d["qst"] = sentence_to_text_field(
model_preprocessing_interface.boundary_token_fn(question), indexers
@@ -313,16 +308,11 @@ def _make_instance(psg, qst, ans_str, label, psg_idx, qst_idx, ans_idx):
d["ans_idx"] = MetadataField(ans_idx)
d["idx"] = MetadataField(ans_idx) # required by evaluate()
if model_preprocessing_interface.model_flags["uses_pair_embedding"]:
- inp = model_preprocessing_interface.boundary_token_fn(
- psg, s2=qst, trunc_strategy="trunc_s1", trunc_side="right"
- )
+ inp = model_preprocessing_interface.boundary_token_fn(psg, qst)
d["psg_qst_ans"] = sentence_to_text_field(inp, indexers)
else:
d["psg"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(
- psg, trunc_strategy="trunc_s1", trunc_side="right"
- ),
- indexers,
+ model_preprocessing_interface.boundary_token_fn(psg), indexers
)
d["qst"] = sentence_to_text_field(
model_preprocessing_interface.boundary_token_fn(qst), indexers
@@ -468,19 +458,12 @@ def _make_instance(example):
if model_preprocessing_interface.model_flags["uses_pair_embedding"]:
inp, start_offset, _ = model_preprocessing_interface.boundary_token_fn(
- example["passage"],
- s2=example["question"],
- trunc_strategy="trunc_s1",
- trunc_side="right",
- get_offset=True,
+ example["passage"], example["question"], get_offset=True
)
d["inputs"] = sentence_to_text_field(inp, indexers)
else:
d["passage"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(
- example["passage"], trunc_strategy="trunc_s1", trunc_side="right"
- ),
- indexers,
+ model_preprocessing_interface.boundary_token_fn(example["passage"]), indexers
)
d["question"] = sentence_to_text_field(
model_preprocessing_interface.boundary_token_fn(example["question"]), indexers
@@ -631,19 +614,12 @@ def _make_instance(example):
if model_preprocessing_interface.model_flags["uses_pair_embedding"]:
inp, start_offset, _ = model_preprocessing_interface.boundary_token_fn(
- example["passage"],
- s2=example["question"],
- trunc_strategy="trunc_s1",
- trunc_side="right",
- get_offset=True,
+ example["passage"], example["question"], get_offset=True
)
d["inputs"] = sentence_to_text_field(inp, indexers)
else:
d["passage"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(
- example["passage"], trunc_strategy="trunc_s1", trunc_side="right"
- ),
- indexers,
+ model_preprocessing_interface.boundary_token_fn(example["passage"]), indexers
)
d["question"] = sentence_to_text_field(
model_preprocessing_interface.boundary_token_fn(example["question"]), indexers
@@ -909,16 +885,11 @@ def _make_instance(question, choices, label, id_str):
d["question_str"] = MetadataField(" ".join(question))
if not model_preprocessing_interface.model_flags["uses_pair_embedding"]:
d["question"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(
- question, trunc_strategy="trunc_s1", trunc_side="right"
- ),
- indexers,
+ model_preprocessing_interface.boundary_token_fn(question), indexers
)
for choice_idx, choice in enumerate(choices):
inp = (
- model_preprocessing_interface.boundary_token_fn(
- question, s2=choice, trunc_strategy="trunc_s1", trunc_side="right"
- )
+ model_preprocessing_interface.boundary_token_fn(question, choice)
if model_preprocessing_interface.model_flags["uses_pair_embedding"]
else model_preprocessing_interface.boundary_token_fn(choice)
)
@@ -1012,16 +983,11 @@ def _make_instance(context, choices, label, id_str):
d["context_str"] = MetadataField(" ".join(context))
if not model_preprocessing_interface.model_flags["uses_pair_embedding"]:
d["context"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(
- context, trunc_strategy="trunc_s1", trunc_side="right"
- ),
- indexers,
+ model_preprocessing_interface.boundary_token_fn(context), indexers
)
for choice_idx, choice in enumerate(choices):
inp = (
- model_preprocessing_interface.boundary_token_fn(
- context, s2=choice, trunc_strategy="trunc_s1", trunc_side="right"
- )
+ model_preprocessing_interface.boundary_token_fn(context, choice)
if model_preprocessing_interface.model_flags["uses_pair_embedding"]
else model_preprocessing_interface.boundary_token_fn(choice)
)
diff --git a/jiant/tasks/tasks.py b/jiant/tasks/tasks.py
index 512a0b37d..a21a6cb34 100644
--- a/jiant/tasks/tasks.py
+++ b/jiant/tasks/tasks.py
@@ -108,25 +108,18 @@ def _make_instance(input1, input2, labels, idx):
d = {}
d["sent1_str"] = MetadataField(" ".join(input1))
if model_preprocessing_interface.model_flags["uses_pair_embedding"] and is_pair:
- inp = model_preprocessing_interface.boundary_token_fn(
- input1, s2=input2, trunc_strategy="trunc_s1", trunc_side="right"
- )
+ inp = model_preprocessing_interface.boundary_token_fn(input1, input2)
d["inputs"] = sentence_to_text_field(inp, indexers)
d["sent2_str"] = MetadataField(" ".join(input2))
if (
model_preprocessing_interface.model_flags["uses_mirrored_pair"]
and is_symmetrical_pair
):
- inp_m = model_preprocessing_interface.boundary_token_fn(
- input1, s2=input2, trunc_strategy="trunc_s1", trunc_side="right"
- )
+ inp_m = model_preprocessing_interface.boundary_token_fn(input1, input2)
d["inputs_m"] = sentence_to_text_field(inp_m, indexers)
else:
d["input1"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(
- input1, trunc_strategy="trunc_s1", trunc_side="right"
- ),
- indexers,
+ model_preprocessing_interface.boundary_token_fn(input1), indexers
)
if input2:
d["input2"] = sentence_to_text_field(
@@ -795,10 +788,7 @@ def _make_instance(input1, labels, tagids):
""" from multiple types in one column create multiple fields """
d = {}
d["input1"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(
- input1, trunc_strategy="trunc_s1", trunc_side="right"
- ),
- indexers,
+ model_preprocessing_interface.boundary_token_fn(input1), indexers
)
d["sent1_str"] = MetadataField(" ".join(input1))
d["labels"] = LabelField(labels, label_namespace="labels", skip_indexing=True)
@@ -1631,16 +1621,11 @@ def _make_instance(input1, input2, label, idx, lex_sem, pr_ar_str, logic, knowle
""" from multiple types in one column create multiple fields """
d = {}
if model_preprocessing_interface.model_flags["uses_pair_embedding"]:
- inp = model_preprocessing_interface.boundary_token_fn(
- input1, s2=input2, trunc_strategy="trunc_s1", trunc_side="right"
- )
+ inp = model_preprocessing_interface.boundary_token_fn(input1, input2)
d["inputs"] = sentence_to_text_field(inp, indexers)
else:
d["input1"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(
- input1, trunc_strategy="trunc_s1", trunc_side="right"
- ),
- indexers,
+ model_preprocessing_interface.boundary_token_fn(input1), indexers
)
d["input2"] = sentence_to_text_field(
model_preprocessing_interface.boundary_token_fn(input2), indexers
@@ -1844,17 +1829,12 @@ def _make_instance(input1, input2, labels, idx, pair_id):
d = {}
d["sent1_str"] = MetadataField(" ".join(input1))
if model_preprocessing_interface.model_flags["uses_pair_embedding"]:
- inp = model_preprocessing_interface.boundary_token_fn(
- input1, s2=input2, trunc_strategy="trunc_s1", trunc_side="right"
- )
+ inp = model_preprocessing_interface.boundary_token_fn(input1, input2)
d["inputs"] = sentence_to_text_field(inp, indexers)
d["sent2_str"] = MetadataField(" ".join(input2))
else:
d["input1"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(
- input1, trunc_strategy="trunc_s1", trunc_side="right"
- ),
- indexers,
+ model_preprocessing_interface.boundary_token_fn(input1), indexers
)
if input2:
d["input2"] = sentence_to_text_field(
@@ -2211,10 +2191,7 @@ def process_split(
def _make_instance(input1, input2, labels):
d = {}
d["input1"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(
- input1, trunc_strategy="trunc_s1", trunc_side="right"
- ),
- indexers,
+ model_preprocessing_interface.boundary_token_fn(input1), indexers
)
d["input2"] = sentence_to_text_field(
model_preprocessing_interface.boundary_token_fn(input2), indexers
@@ -2314,15 +2291,12 @@ def _make_instance(input1, input2, labels):
d = {}
if model_preprocessing_interface.model_flags["uses_pair_embedding"]:
inp = model_preprocessing_interface.boundary_token_fn(
- input1, s2=input2, trunc_strategy="trunc_s1", trunc_side="right"
+ input1, input2
) # drop leading [CLS] token
d["inputs"] = sentence_to_text_field(inp, indexers)
else:
d["input1"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(
- input1, trunc_strategy="trunc_s1", trunc_side="right"
- ),
- indexers,
+ model_preprocessing_interface.boundary_token_fn(input1), indexers
)
d["input2"] = sentence_to_text_field(
model_preprocessing_interface.boundary_token_fn(input2), indexers
@@ -2461,10 +2435,7 @@ def process_split(
def _make_instance(input1, input2, target, mask):
d = {}
d["inputs"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(
- input1, trunc_strategy="trunc_s1", trunc_side="right"
- ),
- indexers,
+ model_preprocessing_interface.boundary_token_fn(input1), indexers
)
d["sent1_str"] = MetadataField(" ".join(input1))
d["targs"] = sentence_to_text_field(target, self.target_indexer)
@@ -2646,9 +2617,7 @@ def _make_span_field(self, s, text_field, offset=1):
def make_instance(self, record, idx, indexers, model_preprocessing_interface) -> Type[Instance]:
"""Convert a single record to an AllenNLP Instance."""
tokens = record["text"].split()
- tokens, offset = model_preprocessing_interface.boundary_token_fn(
- tokens, trunc_strategy="trunc_s1", trunc_side="right", get_offset=True
- )
+ tokens, offset = model_preprocessing_interface.boundary_token_fn(tokens, get_offset=True)
text_field = sentence_to_text_field(tokens, indexers)
example = {}
@@ -2866,16 +2835,12 @@ def _make_instance(input1, input2, idxs1, idxs2, labels, idx):
d["sent2_str"] = MetadataField(" ".join(input2))
if model_preprocessing_interface.model_flags["uses_pair_embedding"]:
inp, offset1, offset2 = model_preprocessing_interface.boundary_token_fn(
- input1,
- s2=input2,
- trunc_strategy="trunc_s1",
- trunc_side="right",
- get_offset=True,
+ input1, input2, get_offset=True
)
d["inputs"] = sentence_to_text_field(inp[: self.max_seq_len], indexers)
else:
inp1, offset1 = model_preprocessing_interface.boundary_token_fn(
- input1, trunc_strategy="trunc_s1", trunc_side="right", get_offset=True
+ input1, get_offset=True
)
inp2, offset2 = model_preprocessing_interface.boundary_token_fn(
input2, get_offset=True
@@ -3003,16 +2968,11 @@ def _make_instance(context, choices, question, label, idx):
d["question_str"] = MetadataField(" ".join(context))
if not model_preprocessing_interface.model_flags["uses_pair_embedding"]:
d["question"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(
- context, trunc_strategy="trunc_s1", trunc_side="right"
- ),
- indexers,
+ model_preprocessing_interface.boundary_token_fn(context), indexers
)
for choice_idx, choice in enumerate(choices):
inp = (
- model_preprocessing_interface.boundary_token_fn(
- context, s2=question + choice, trunc_strategy="trunc_s1", trunc_side="right"
- )
+ model_preprocessing_interface.boundary_token_fn(context, question + choice)
if model_preprocessing_interface.model_flags["uses_pair_embedding"]
else model_preprocessing_interface.boundary_token_fn(choice)
)
@@ -3171,16 +3131,11 @@ def _make_instance(context, choices, question, label, idx):
d["question_str"] = MetadataField(" ".join(context))
if not model_preprocessing_interface.model_flags["uses_pair_embedding"]:
d["question"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(
- context, trunc_strategy="trunc_s1", trunc_side="right"
- ),
- indexers,
+ model_preprocessing_interface.boundary_token_fn(context), indexers
)
for choice_idx, choice in enumerate(choices):
inp = (
- model_preprocessing_interface.boundary_token_fn(
- context, s2=question + choice, trunc_strategy="trunc_s1", trunc_side="right"
- )
+ model_preprocessing_interface.boundary_token_fn(context, question + choice)
if model_preprocessing_interface.model_flags["uses_pair_embedding"]
else model_preprocessing_interface.boundary_token_fn(choice)
)
@@ -3262,16 +3217,11 @@ def _make_instance(question, choices, label, idx):
d["question_str"] = MetadataField(" ".join(question))
if not model_preprocessing_interface.model_flags["uses_pair_embedding"]:
d["question"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(
- question, trunc_strategy="trunc_s1", trunc_side="right"
- ),
- indexers,
+ model_preprocessing_interface.boundary_token_fn(question), indexers
)
for choice_idx, choice in enumerate(choices):
inp = (
- model_preprocessing_interface.boundary_token_fn(
- question, s2=choice, trunc_strategy="trunc_s1", trunc_side="right"
- )
+ model_preprocessing_interface.boundary_token_fn(question, choice)
if model_preprocessing_interface.model_flags["uses_pair_embedding"]
else model_preprocessing_interface.boundary_token_fn(choice)
)
@@ -3358,16 +3308,11 @@ def _make_instance(question, choices, label, idx):
d["question_str"] = MetadataField(" ".join(question))
if not model_preprocessing_interface.model_flags["uses_pair_embedding"]:
d["question"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(
- question, trunc_strategy="trunc_s1", trunc_side="right"
- ),
- indexers,
+ model_preprocessing_interface.boundary_token_fn(question), indexers
)
for choice_idx, choice in enumerate(choices):
inp = (
- model_preprocessing_interface.boundary_token_fn(
- question, s2=choice, trunc_strategy="trunc_s1", trunc_side="right"
- )
+ model_preprocessing_interface.boundary_token_fn(question, choice)
if model_preprocessing_interface.model_flags["uses_pair_embedding"]
else model_preprocessing_interface.boundary_token_fn(choice)
)
@@ -3486,17 +3431,14 @@ def _make_instance(d, idx):
new_d["passage_str"] = MetadataField(" ".join(d["passage"]))
if not model_preprocessing_interface.model_flags["uses_pair_embedding"]:
new_d["input1"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(
- d["passage"], trunc_strategy="trunc_s1", trunc_side="right"
- ),
- indexers,
+ model_preprocessing_interface.boundary_token_fn(d["passage"]), indexers
)
new_d["input2"] = sentence_to_text_field(
model_preprocessing_interface.boundary_token_fn(d["question"]), indexers
)
else: # BERT/XLNet
psg_qst = model_preprocessing_interface.boundary_token_fn(
- d["passage"], s2=d["question"], trunc_strategy="trunc_s1", trunc_side="right"
+ d["passage"], d["question"]
)
new_d["inputs"] = sentence_to_text_field(psg_qst, indexers)
new_d["labels"] = LabelField(d["label"], label_namespace="labels", skip_indexing=True)
@@ -3619,7 +3561,7 @@ def _make_instance(obs1, hyp1, hyp2, obs2, label, idx):
else:
for hyp_idx, hyp in enumerate([hyp1, hyp2]):
inp = (
- model_preprocessing_interface.boundary_token_fn(obs1 + hyp, s2=obs2)
+ model_preprocessing_interface.boundary_token_fn(obs1 + hyp, obs2)
if model_preprocessing_interface.model_flags["uses_pair_embedding"]
else model_preprocessing_interface.boundary_token_fn(hyp)
)
@@ -3767,16 +3709,11 @@ def _make_instance(context, choices, label, idx):
d["question_str"] = MetadataField(" ".join(context))
if not model_preprocessing_interface.model_flags["uses_pair_embedding"]:
d["question"] = sentence_to_text_field(
- model_preprocessing_interface.boundary_token_fn(
- context, trunc_strategy="trunc_s1", trunc_side="right"
- ),
- indexers,
+ model_preprocessing_interface.boundary_token_fn(context), indexers
)
for choice_idx, choice in enumerate(choices):
inp = (
- model_preprocessing_interface.boundary_token_fn(
- context, s2=choice, trunc_strategy="trunc_s1", trunc_side="right"
- )
+ model_preprocessing_interface.boundary_token_fn(context, choice)
if model_preprocessing_interface.model_flags["uses_pair_embedding"]
else model_preprocessing_interface.boundary_token_fn(choice)
)
diff --git a/jiant/utils/utils.py b/jiant/utils/utils.py
index 34c87a428..3dca1174c 100644
--- a/jiant/utils/utils.py
+++ b/jiant/utils/utils.py
@@ -98,7 +98,7 @@ def select_pool_type(args):
return pool_type
-def apply_standard_boundary_tokens(s1, *, s2=None):
+def apply_standard_boundary_tokens(s1, s2=None):
"""Apply and to sequences of string-valued tokens.
Corresponds to more complex functions used with models like XLNet and BERT.
"""
diff --git a/tests/test_huggingface_transformers_interface.py b/tests/test_huggingface_transformers_interface.py
index bb6eabdde..de5e0dc7e 100644
--- a/tests/test_huggingface_transformers_interface.py
+++ b/tests/test_huggingface_transformers_interface.py
@@ -23,7 +23,7 @@ def test_bert_apply_boundary_tokens(self):
BertEmbedderModule.apply_boundary_tokens(s1), ["[CLS]", "A", "B", "C", "[SEP]"]
)
self.assertListEqual(
- BertEmbedderModule.apply_boundary_tokens(s1, s2=s2),
+ BertEmbedderModule.apply_boundary_tokens(s1, s2),
["[CLS]", "A", "B", "C", "[SEP]", "D", "E", "[SEP]"],
)
@@ -34,7 +34,7 @@ def test_roberta_apply_boundary_tokens(self):
RobertaEmbedderModule.apply_boundary_tokens(s1), ["", "A", "B", "C", ""]
)
self.assertListEqual(
- RobertaEmbedderModule.apply_boundary_tokens(s1, s2=s2),
+ RobertaEmbedderModule.apply_boundary_tokens(s1, s2),
["", "A", "B", "C", "", "", "D", "E", ""],
)
@@ -45,7 +45,7 @@ def test_albert_apply_boundary_tokens(self):
AlbertEmbedderModule.apply_boundary_tokens(s1), ["[CLS]", "A", "B", "C", "[SEP]"]
)
self.assertListEqual(
- AlbertEmbedderModule.apply_boundary_tokens(s1, s2=s2),
+ AlbertEmbedderModule.apply_boundary_tokens(s1, s2),
["[CLS]", "A", "B", "C", "[SEP]", "D", "E", "[SEP]"],
)
@@ -56,7 +56,7 @@ def test_xlnet_apply_boundary_tokens(self):
XLNetEmbedderModule.apply_boundary_tokens(s1), ["A", "B", "C", "", ""]
)
self.assertListEqual(
- XLNetEmbedderModule.apply_boundary_tokens(s1, s2=s2),
+ XLNetEmbedderModule.apply_boundary_tokens(s1, s2),
["A", "B", "C", "", "D", "E", "", ""],
)
@@ -68,7 +68,7 @@ def test_gpt_apply_boundary_tokens(self):
["", "A", "B", "C", ""],
)
self.assertListEqual(
- OpenAIGPTEmbedderModule.apply_boundary_tokens(s1, s2=s2),
+ OpenAIGPTEmbedderModule.apply_boundary_tokens(s1, s2),
["", "A", "B", "C", "", "D", "E", ""],
)
@@ -79,7 +79,7 @@ def test_xlm_apply_boundary_tokens(self):
XLMEmbedderModule.apply_boundary_tokens(s1), ["", "A", "B", "C", ""]
)
self.assertListEqual(
- XLMEmbedderModule.apply_boundary_tokens(s1, s2=s2),
+ XLMEmbedderModule.apply_boundary_tokens(s1, s2),
["", "A", "B", "C", "", "D", "E", ""],
)
diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py
index 2e2c6bda1..ef95a8031 100644
--- a/tests/test_preprocess.py
+++ b/tests/test_preprocess.py
@@ -8,14 +8,8 @@
from unittest import mock
import jiant.tasks.tasks as tasks
-from jiant.utils import config
from jiant.utils.config import params_from_file
-from jiant.preprocess import (
- get_task_without_loading_data,
- build_indexers,
- get_vocab,
- ModelPreprocessingInterface,
-)
+from jiant.preprocess import get_task_without_loading_data, build_indexers, get_vocab
class TestProprocess(unittest.TestCase):
@@ -81,118 +75,3 @@ def test_build_vocab(self):
assert set(vocab.get_index_to_token_vocabulary("chars").values()) == set(
["@@PADDING@@", "@@UNKNOWN@@", "a", "b", "c"]
)
-
-
-class TestModelPreprocessingInterface(unittest.TestCase):
- def test_max_tokens_limited_by_small_max_seq_len(self):
- args = config.Params(max_seq_len=7, input_module="bert-base-uncased")
- mpi = ModelPreprocessingInterface(args)
- self.assertEqual(mpi.max_tokens, 7)
-
- def test_max_tokens_limited_by_bert_model_max(self):
- args = config.Params(max_seq_len=None, input_module="bert-base-uncased")
- mpi = ModelPreprocessingInterface(args)
- self.assertEqual(mpi.max_tokens, 512)
-
- def test_boundary_token_fn_trunc_w_default_strategies(self):
- MAX_SEQ_LEN = 7
- args = config.Params(max_seq_len=MAX_SEQ_LEN, input_module="bert-base-uncased")
- mpi = ModelPreprocessingInterface(args)
- seq = mpi.boundary_token_fn(
- ["Apple", "buy", "call"],
- s2=["Xray", "you", "zoo"],
- trunc_strategy="trunc_s2",
- trunc_side="right",
- )
- self.assertEqual(len(seq), MAX_SEQ_LEN)
- self.assertEqual(seq, ["[CLS]", "Apple", "buy", "call", "[SEP]", "Xray", "[SEP]"])
-
- def test_boundary_token_fn_trunc_s2_left_side(self):
- MAX_SEQ_LEN = 7
- args = config.Params(max_seq_len=MAX_SEQ_LEN, input_module="bert-base-uncased")
- mpi = ModelPreprocessingInterface(args)
- seq = mpi.boundary_token_fn(
- ["Apple", "buy", "call"],
- s2=["Xray", "you", "zoo"],
- trunc_strategy="trunc_s2",
- trunc_side="left",
- )
- self.assertEqual(len(seq), MAX_SEQ_LEN)
- self.assertEqual(seq, ["[CLS]", "Apple", "buy", "call", "[SEP]", "zoo", "[SEP]"])
-
- def test_boundary_token_fn_trunc_s2_right_side(self):
- MAX_SEQ_LEN = 7
- args = config.Params(max_seq_len=MAX_SEQ_LEN, input_module="bert-base-uncased")
- mpi = ModelPreprocessingInterface(args)
- seq = mpi.boundary_token_fn(
- ["Apple", "buy", "call"],
- s2=["Xray", "you", "zoo"],
- trunc_strategy="trunc_s2",
- trunc_side="right",
- )
- self.assertEqual(len(seq), MAX_SEQ_LEN)
- self.assertEqual(seq, ["[CLS]", "Apple", "buy", "call", "[SEP]", "Xray", "[SEP]"])
-
- def test_boundary_token_fn_trunc_s1_left_side(self):
- MAX_SEQ_LEN = 7
- args = config.Params(max_seq_len=MAX_SEQ_LEN, input_module="bert-base-uncased")
- mpi = ModelPreprocessingInterface(args)
- seq = mpi.boundary_token_fn(
- ["Apple", "buy", "call"],
- s2=["Xray", "you", "zoo"],
- trunc_strategy="trunc_s1",
- trunc_side="left",
- )
- self.assertEqual(len(seq), MAX_SEQ_LEN)
- self.assertEqual(seq, ["[CLS]", "call", "[SEP]", "Xray", "you", "zoo", "[SEP]"])
-
- def test_boundary_token_fn_trunc_s1_right_side(self):
- MAX_SEQ_LEN = 7
- args = config.Params(max_seq_len=MAX_SEQ_LEN, input_module="bert-base-uncased")
- mpi = ModelPreprocessingInterface(args)
- seq = mpi.boundary_token_fn(
- ["Apple", "buy", "call"],
- s2=["Xray", "you", "zoo"],
- trunc_strategy="trunc_s1",
- trunc_side="right",
- )
- self.assertEqual(len(seq), MAX_SEQ_LEN)
- self.assertEqual(seq, ["[CLS]", "Apple", "[SEP]", "Xray", "you", "zoo", "[SEP]"])
-
- def test_boundary_token_fn_trunc_both(self):
- MAX_SEQ_LEN = 7
- args = config.Params(max_seq_len=MAX_SEQ_LEN, input_module="bert-base-uncased")
- mpi = ModelPreprocessingInterface(args)
- seq = mpi.boundary_token_fn(
- ["Apple", "buy", "call"],
- s2=["Xray", "you", "zoo"],
- trunc_strategy="trunc_both",
- trunc_side="left",
- )
- self.assertEqual(len(seq), MAX_SEQ_LEN)
- self.assertEqual(seq, ["[CLS]", "buy", "call", "[SEP]", "you", "zoo", "[SEP]"])
-
- def test_boundary_token_fn_trunc_with_short_sequence_and_no_max_seq_len(self):
- args = config.Params(max_seq_len=None, input_module="bert-base-uncased")
- mpi = ModelPreprocessingInterface(args)
- seq = mpi.boundary_token_fn(["Apple", "buy", "call"], s2=["Xray", "you", "zoo"])
- self.assertEqual(
- seq, ["[CLS]", "Apple", "buy", "call", "[SEP]", "Xray", "you", "zoo", "[SEP]"]
- )
-
- def test_boundary_token_fn_throws_exception_when_trunc_is_needed_but_strat_unspecified(self):
- args = config.Params(max_seq_len=1, input_module="bert-base-uncased")
- mpi = ModelPreprocessingInterface(args)
- self.assertRaises(ValueError, mpi.boundary_token_fn, ["Apple", "buy", "call"], s2=["Xray"])
-
- def test_boundary_token_fn_throws_exception_when_trunc_beyond_input_length(self):
- args = config.Params(max_seq_len=1, input_module="bert-base-uncased")
- mpi = ModelPreprocessingInterface(args)
- self.assertRaises(
- AssertionError,
- mpi.boundary_token_fn,
- ["Apple", "buy", "call"],
- s2=["Xray"],
- trunc_strategy="trunc_s2",
- trunc_side="right",
- )