Skip to content

Commit

Permalink
Revert "Address max sequence length issue (#1002) (#1006)" (#1013)
Browse files Browse the repository at this point in the history
  • Loading branch information
pyeres authored Feb 25, 2020
1 parent 61211f3 commit 5d7a1c8
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 428 deletions.
18 changes: 9 additions & 9 deletions jiant/huggingface_transformers_interface/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def get_seg_ids(self, token_ids, input_mask):
return seg_ids

@staticmethod
def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
def apply_boundary_tokens(s1, s2=None, get_offset=False):
"""
A function that appliese the appropriate EOS/SOS/SEP/CLS tokens to token sequence or
token sequence pair for most tasks.
Expand Down Expand Up @@ -274,7 +274,7 @@ def __init__(self, args):
self.parameter_setup(args)

@staticmethod
def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
def apply_boundary_tokens(s1, s2=None, get_offset=False):
# BERT-style boundary token padding on string token sequences
if s2:
s = ["[CLS]"] + s1 + ["[SEP]"] + s2 + ["[SEP]"]
Expand Down Expand Up @@ -331,7 +331,7 @@ def __init__(self, args):
self.parameter_setup(args)

@staticmethod
def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
def apply_boundary_tokens(s1, s2=None, get_offset=False):
# RoBERTa-style boundary token padding on string token sequences
if s2:
s = ["<s>"] + s1 + ["</s>", "</s>"] + s2 + ["</s>"]
Expand Down Expand Up @@ -385,7 +385,7 @@ def __init__(self, args):
self.parameter_setup(args)

@staticmethod
def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
def apply_boundary_tokens(s1, s2=None, get_offset=False):
# ALBERT-style boundary token padding on string token sequences
if s2:
s = ["[CLS]"] + s1 + ["[SEP]"] + s2 + ["[SEP]"]
Expand Down Expand Up @@ -448,7 +448,7 @@ def __init__(self, args):
self._SEG_ID_SEP = 3

@staticmethod
def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
def apply_boundary_tokens(s1, s2=None, get_offset=False):
# XLNet-style boundary token marking on string token sequences
if s2:
s = s1 + ["<sep>"] + s2 + ["<sep>", "<cls>"]
Expand Down Expand Up @@ -506,7 +506,7 @@ def __init__(self, args):
self.parameter_setup(args)

@staticmethod
def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
def apply_boundary_tokens(s1, s2=None, get_offset=False):
# OpenAI-GPT-style boundary token marking on string token sequences
if s2:
s = ["<start>"] + s1 + ["<delim>"] + s2 + ["<extract>"]
Expand Down Expand Up @@ -568,7 +568,7 @@ def __init__(self, args):
self.parameter_setup(args)

@staticmethod
def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
def apply_boundary_tokens(s1, s2=None, get_offset=False):
# GPT-2-style boundary token marking on string token sequences
if s2:
s = ["<start>"] + s1 + ["<delim>"] + s2 + ["<extract>"]
Expand Down Expand Up @@ -630,7 +630,7 @@ def __init__(self, args):
self.parameter_setup(args)

@staticmethod
def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
def apply_boundary_tokens(s1, s2=None, get_offset=False):
# TransformerXL-style boundary token marking on string token sequences
if s2:
s = ["<start>"] + s1 + ["<delim>"] + s2 + ["<extract>"]
Expand Down Expand Up @@ -697,7 +697,7 @@ def __init__(self, args):
self.parameter_setup(args)

@staticmethod
def apply_boundary_tokens(s1, *, s2=None, get_offset=False):
def apply_boundary_tokens(s1, s2=None, get_offset=False):
# XLM-style boundary token marking on string token sequences
if s2:
s = ["</s>"] + s1 + ["</s>"] + s2 + ["</s>"]
Expand Down
155 changes: 3 additions & 152 deletions jiant/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,202 +740,53 @@ class ModelPreprocessingInterface(object):
"""

@staticmethod
def _apply_boundary_tokens(apply_boundary_tokens, max_tokens):
"""
Takes a function that applies boundary tokens and modifies it to respect a max_seq_len arg.
Parameters
----------
apply_boundary_tokens : function
Takes a function that adds boundary tokens and implements the common boundary token
adding function interface demonstated in HuggingfaceTransformersEmbedderModule.
max_tokens : int
The maximum number of tokens to allow in the output sequence.
Returns
-------
apply_boundary_tokens_with_trunc_strategy : function
Function w/ the common boundary token adding interface demonstrated in
HuggingfaceTransformersEmbedderModule, but also implementing a truncation strategy.
"""

def _apply_boundary_tokens_with_trunc_strategy(
*args, trunc_strategy=None, trunc_side=None, **kwargs
):
"""
Calls the apply_boundary_tokens function provided in the parent function and, if the
output exceeds the max number of tokens, applies a truncation strategy to the inputs,
then re-applies the apply_boundary_tokens function on the truncated inputs and returns
the result.
Parameters
----------
args : tuple
see args docs for apply_boundary_tokens and apply_lm_boundary_tokens in
huggingface_transformers_interface.modules
trunc_strategy : str
Which strings to truncate. Options:
`trunc_s1`: select s1 for truncation.
`trunc_s2`: select s2 for truncation.
`trunc_both`: truncate s1 and s2 equally.
trunc_side : str
Options are `right` or `left`. Indicates which side of the sequence to remove from.
e.g., if `left`, then the first element to be removed from ["A", "B", "C"] is "A".
kwargs : dict
see args docs for apply_boundary_tokens and apply_lm_boundary_tokens in
huggingface_transformers_interface.modules
Returns
-------
seq_w_boundry_tokens : List[str]
List of tokens returned by applying the specified apply_boundary_tokens function
to the inputs and using the specified truncation strategy.
Raises
------
ValueError
If sequence requires truncation but trunc_strategy or trunc_side are unspecified
or invalid.
"""
seq_w_boundry_tokens = apply_boundary_tokens(*args, **kwargs)
# if after calling the apply_boundary_tokens function the number of tokens is greater
# than the model's max, a truncation strategy can reduce the number of tokens:
num_excess_tokens = len(seq_w_boundry_tokens) - max_tokens
if num_excess_tokens > 0:
if not (trunc_strategy and trunc_side):
raise ValueError(
"Input(s) length will exceed model capacity or max_seq_len. "
+ "Adjust settings as necessary, or, to automatically truncate "
+ "inputs in `apply_boundary_tokens`, call `apply_boundary_tokens` "
+ "with `trunc_strategy` and `trunc_side` keyword args."
)
if trunc_strategy == "trunc_s2":
s2 = kwargs["s2"]
if trunc_side == "right":
s2_truncated = s2[:-num_excess_tokens]
elif trunc_side == "left":
s2_truncated = s2[num_excess_tokens:]
log.info(
"Before truncation s2 length = "
+ str(len(s2))
+ ", after truncation s2 length = "
+ str(len(s2_truncated))
)
assert len(s2_truncated) > 0, "After truncation, s2 length would be 0."
args = list(args)
kwargs["s2"] = s2_truncated
return apply_boundary_tokens(*args, **kwargs)
elif trunc_strategy == "trunc_both":
s1 = args[0]
s2 = kwargs["s2"]
if trunc_side == "right":
s1_truncated = s1[: -num_excess_tokens // 2]
s2_truncated = s2[: -num_excess_tokens // 2]
elif trunc_side == "left":
s1_truncated = s1[num_excess_tokens // 2 :]
s2_truncated = s2[num_excess_tokens // 2 :]
log.info(
"Before truncation s1 length = "
+ str(len(s1))
+ " and s2 length = "
+ str(len(s2))
+ ". After truncation s1 length = "
+ str(len(s1_truncated))
+ " and s2 length = "
+ str(len(s2_truncated))
)
assert len(s1_truncated) > 0, "After truncation, s1 length would be 0."
assert len(s2_truncated) > 0, "After truncation, s2 length would be 0."
args = list(args)
args[0] = s1_truncated
kwargs["s2"] = s2_truncated
return apply_boundary_tokens(*args, **kwargs)
elif trunc_strategy == "trunc_s1":
s1 = args[0]
if trunc_side == "right":
s1_truncated = s1[:-num_excess_tokens]
elif trunc_side == "left":
s1_truncated = s1[num_excess_tokens:]
log.info(
"Before truncation s1 length = "
+ str(len(s1))
+ ", after truncation s1 length = "
+ str(len(s1_truncated))
)
assert len(s1_truncated) > 0, "After truncation, s1 length would be 0."
args = list(args)
args[0] = s1_truncated
return apply_boundary_tokens(*args, **kwargs)
else:
raise ValueError(trunc_strategy + " is not a valid truncation strategy.")
else:
return seq_w_boundry_tokens

return _apply_boundary_tokens_with_trunc_strategy

def __init__(self, args):
boundary_token_fn = None
lm_boundary_token_fn = None
max_pos: int = None

if args.input_module.startswith("bert-"):
from jiant.huggingface_transformers_interface.modules import BertEmbedderModule

boundary_token_fn = BertEmbedderModule.apply_boundary_tokens
max_pos = BertTokenizer.max_model_input_sizes.get(args.input_module, None)
elif args.input_module.startswith("roberta-"):
from jiant.huggingface_transformers_interface.modules import RobertaEmbedderModule

boundary_token_fn = RobertaEmbedderModule.apply_boundary_tokens
max_pos = RobertaTokenizer.max_model_input_sizes.get(args.input_module, None)
elif args.input_module.startswith("albert-"):
from jiant.huggingface_transformers_interface.modules import AlbertEmbedderModule

boundary_token_fn = AlbertEmbedderModule.apply_boundary_tokens
max_pos = AlbertTokenizer.max_model_input_sizes.get(args.input_module, None)
elif args.input_module.startswith("xlnet-"):
from jiant.huggingface_transformers_interface.modules import XLNetEmbedderModule

boundary_token_fn = XLNetEmbedderModule.apply_boundary_tokens
max_pos = XLNetTokenizer.max_model_input_sizes.get(args.input_module, None)
elif args.input_module.startswith("openai-gpt"):
from jiant.huggingface_transformers_interface.modules import OpenAIGPTEmbedderModule

boundary_token_fn = OpenAIGPTEmbedderModule.apply_boundary_tokens
lm_boundary_token_fn = OpenAIGPTEmbedderModule.apply_lm_boundary_tokens
max_pos = OpenAIGPTTokenizer.max_model_input_sizes.get(args.input_module, None)
elif args.input_module.startswith("gpt2"):
from jiant.huggingface_transformers_interface.modules import GPT2EmbedderModule

boundary_token_fn = GPT2EmbedderModule.apply_boundary_tokens
lm_boundary_token_fn = GPT2EmbedderModule.apply_lm_boundary_tokens
max_pos = GPT2Tokenizer.max_model_input_sizes.get(args.input_module, None)
elif args.input_module.startswith("transfo-xl-"):
from jiant.huggingface_transformers_interface.modules import TransfoXLEmbedderModule

boundary_token_fn = TransfoXLEmbedderModule.apply_boundary_tokens
lm_boundary_token_fn = TransfoXLEmbedderModule.apply_lm_boundary_tokens
max_pos = TransfoXLTokenizer.max_model_input_sizes.get(args.input_module, None)
elif args.input_module.startswith("xlm-"):
from jiant.huggingface_transformers_interface.modules import XLMEmbedderModule

boundary_token_fn = XLMEmbedderModule.apply_boundary_tokens
max_pos = XLMTokenizer.max_model_input_sizes.get(args.input_module, None)
else:
boundary_token_fn = utils.apply_standard_boundary_tokens

self.max_tokens = min(x for x in [max_pos, args.max_seq_len] if x is not None)

self.boundary_token_fn = self._apply_boundary_tokens(boundary_token_fn, self.max_tokens)

self.boundary_token_fn = boundary_token_fn
if lm_boundary_token_fn is not None:
self.lm_boundary_token_fn = self._apply_boundary_tokens(
lm_boundary_token_fn, args.max_seq_len
)
self.lm_boundary_token_fn = lm_boundary_token_fn
else:
self.lm_boundary_token_fn = self.boundary_token_fn
self.lm_boundary_token_fn = boundary_token_fn

from jiant.models import input_module_uses_pair_embedding, input_module_uses_mirrored_pair

Expand Down
2 changes: 1 addition & 1 deletion jiant/tasks/edge_probing.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def make_instance(self, record, idx, indexers, model_preprocessing_interface) ->
"""Convert a single record to an AllenNLP Instance."""
tokens = record["text"].split() # already space-tokenized by Moses
tokens = model_preprocessing_interface.boundary_token_fn(
tokens, trunc_strategy="trunc_s1", trunc_side="right"
tokens
) # apply model-appropriate variants of [cls] and [sep].
text_field = sentence_to_text_field(tokens, indexers)

Expand Down
Loading

0 comments on commit 5d7a1c8

Please sign in to comment.