fix "\n" tokenization + phi-2 new layer names (OpenNMT#2552)
vince62s authored Jan 18, 2024
1 parent 8045a86 commit b67e492
Showing 15 changed files with 126 additions and 85 deletions.
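Most of the changes below replace whitespace-agnostic str.split() (and strip("\n")) with explicit str.split(" ") (and strip()), so that tokens which are themselves newlines survive pretokenization instead of being swallowed as ordinary whitespace. A minimal sketch of the difference, using a made-up example line rather than code from the repository:

# str.split() treats any whitespace run, including "\n", as a single separator;
# str.split(" ") only splits on single spaces, so newline tokens are preserved.
line = "Hello world \n \n Goodbye"
print(line.split())       # ['Hello', 'world', 'Goodbye']             -- newline tokens lost
print(line.split(" "))    # ['Hello', 'world', '\n', '\n', 'Goodbye']  -- newline tokens kept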
6 changes: 3 additions & 3 deletions eval_llm/MMLU/run_mmlu_opennmt.py
@@ -167,12 +167,12 @@ def evaluate(opt):
prompt_end = format_example(test_df, i, include_answer=False)
train_prompt = gen_prompt(dev_df, task, k)
prompt = train_prompt + prompt_end
"""
while len(prompt.split()) > 768:

while len(prompt.split(" ")) > 768:
prompt_split = prompt.split("\n\n")
prompt_split.pop(1)
prompt = "\n\n".join(prompt_split)
"""

label = test_df.iloc[i, test_df.shape[1] - 1]
records.append({"prompt": prompt, "answer": label})
src.append(prompt.replace("\n", "⦅newline⦆"))
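The hunk above re-enables the previously commented-out prompt-truncation loop and counts words with split(" "). Each pass splits the few-shot prompt on blank lines and drops the element at index 1, i.e. the oldest in-context example, while keeping the task instruction at index 0 and the final question. A hedged toy illustration of that behaviour (the strings and the threshold are made up; the real script uses 768):

header = "The following are multiple choice questions about astronomy."
shots = ["Q1 ... Answer: A", "Q2 ... Answer: C"]
question = "Q3 ...\nAnswer:"
prompt = "\n\n".join([header] + shots + [question])
while len(prompt.split(" ")) > 10:
    prompt_split = prompt.split("\n\n")
    prompt_split.pop(1)              # drop the oldest few-shot example, keep the instruction header
    prompt = "\n\n".join(prompt_split)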
5 changes: 2 additions & 3 deletions eval_llm/WIKITEXT2/run_wikitext-2_benchmark.py
@@ -119,7 +119,7 @@ def evaluate(opt):
engine = InferenceEnginePY(engine_opt)

# Tokenize the dataset.
opt.src = "wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw"
opt.src = "eval_llm/WIKITEXT2/wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw"
tokenize_dataset(opt, context_length=512)

# Score the tokenized dataset
@@ -140,8 +140,7 @@ def evaluate(opt):

def _get_parser():
parser = ArgumentParser(description="run_wikitext-2_benchmark.py")
opts.config_opts(parser)
opts.translate_opts(parser, dynamic=True)
opts.translate_opts(parser)
return parser


5 changes: 4 additions & 1 deletion onmt/decoders/ensemble.py
@@ -65,7 +65,10 @@ def forward(self, tgt, enc_out, src_len=None, step=None, **kwargs):
for i, model_decoder in enumerate(self.model_decoders)
]
)
mean_attns = self.combine_attns(attns)
if attns[0]["std"] is not None:
mean_attns = self.combine_attns(attns)
else:
mean_attns = attns
return EnsembleDecoderOutput(dec_outs), mean_attns

def combine_attns(self, attns):
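The new guard above skips averaging whenever the first decoder reports no standard attention (attns[0]["std"] is None) and simply passes the attentions through. The body of combine_attns is folded out of the diff; as an assumption, a minimal sketch of what per-key averaging across ensemble members could look like, given same-shaped tensors:

import torch

def combine_attns_sketch(attns):
    # attns: list of dicts (one per ensemble member) mapping keys such as "std"
    # to attention tensors of identical shape.
    result = {}
    for key in attns[0].keys():
        result[key] = torch.stack([attn[key] for attn in attns]).mean(dim=0)
    return result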
12 changes: 6 additions & 6 deletions onmt/inputters/text_corpus.py
@@ -174,20 +174,20 @@ def __init__(

def _process(self, stream):
for i, example in enumerate(stream):
example["src"] = example["src"].strip("\n").split()
example["src_original"] = example["src_original"].strip("\n").split()
example["src"] = example["src"].strip().split(" ")
example["src_original"] = example["src_original"].strip().split(" ")
if "src_feats" in example:
example["src_feats"] = [
feat.strip("\n").split() for feat in example["src_feats"]
feat.strip().split(" ") for feat in example["src_feats"]
]
line_number = i * self.stride + self.offset
example["cid_line_number"] = line_number
example["cid"] = self.cid
if "align" in example:
example["align"] = example["align"].strip("\n").split()
example["align"] = example["align"].strip().split(" ")
if example["tgt"] is not None:
example["tgt"] = example["tgt"].strip("\n").split()
example["tgt_original"] = example["tgt_original"].strip("\n").split()
example["tgt"] = example["tgt"].strip().split(" ")
example["tgt_original"] = example["tgt_original"].strip().split(" ")
if (
len(example["src"]) == 0
or len(example["tgt"]) == 0
20 changes: 11 additions & 9 deletions onmt/inputters/text_utils.py
@@ -121,31 +121,33 @@ def numericalize(vocabs, example):
numeric = example
numeric["src"]["src_ids"] = []
if vocabs["data_task"] == ModelTask.SEQ2SEQ:
src_text = example["src"]["src"].split()
src_text = example["src"]["src"].split(" ")
numeric["src"]["src_ids"] = vocabs["src"](src_text)
if example["tgt"] is not None:
numeric["tgt"]["tgt_ids"] = []
tgt_text = example["tgt"]["tgt"].split()
tgt_text = example["tgt"]["tgt"].split(" ")
numeric["tgt"]["tgt_ids"] = vocabs["tgt"](
[decoder_start_token] + tgt_text + [DefaultTokens.EOS]
)

elif vocabs["data_task"] == ModelTask.LANGUAGE_MODEL:
src_text = example["src"]["src"].split()
src_text = example["src"]["src"].split(" ")
if decoder_start_token != "":
src_text = [decoder_start_token] + src_text
numeric["src"]["src_ids"] = vocabs["src"](src_text)
if example["tgt"] is not None:
numeric["tgt"]["tgt_ids"] = []
tgt_text = example["tgt"]["tgt"].split()
tgt_text = example["tgt"]["tgt"].split(" ")
numeric["tgt"]["tgt_ids"] = vocabs["tgt"](tgt_text + [DefaultTokens.EOS])
if decoder_start_token == "":
numeric["tgt"]["tgt_ids"] = numeric["tgt"]["tgt_ids"][1:]
else:
raise ValueError(f"Something went wrong with task {vocabs['data_task']}")

if "feats" in example["src"]:
numeric_feats = []
for fv, feat in zip(vocabs["src_feats"], example["src"]["feats"]):
numeric_feats.append(fv(feat.split()))
numeric_feats.append(fv(feat.split(" ")))
numeric["src"]["feats"] = numeric_feats

return numeric
@@ -329,7 +331,7 @@ def textbatch_to_tensor(vocabs, batch, device, is_train=False):
infer_iter = []
for i, ex in enumerate(batch):
# Keep it consistent with dynamic data
ex["srclen"] = len(ex["src"]["src"].split())
ex["srclen"] = len(ex["src"]["src"].split(" "))
ex["in_in_bucket"] = i
ex["cid"] = "text"
ex["cid_line_number"] = i
@@ -354,7 +356,7 @@ def _addcopykeys(vocabs, example):
Returns:
``example``, changed as described.
"""
src = example["src"]["src"].split()
src = example["src"]["src"].split(" ")
src_ex_vocab = pyonmttok.build_vocab_from_tokens(
Counter(src),
maximum_size=0,
@@ -377,10 +379,10 @@ def _addcopykeys(vocabs, example):
if vocabs["data_task"] == ModelTask.SEQ2SEQ:
tgt = (
[DefaultTokens.UNK]
+ example["tgt"]["tgt"].split()
+ example["tgt"]["tgt"].split(" ")
+ [DefaultTokens.UNK]
)
elif vocabs["data_task"] == ModelTask.LANGUAGE_MODEL:
tgt = example["tgt"]["tgt"].split() + [DefaultTokens.UNK]
tgt = example["tgt"]["tgt"].split(" ") + [DefaultTokens.UNK]
example["alignment"] = src_ex_vocab(tgt)
return example
2 changes: 1 addition & 1 deletion onmt/transforms/fuzzymatch.py
@@ -216,6 +216,6 @@ def batch_apply(self, batch, is_train=False, stats=None, **kwargs):
assert len(src_segments) == len(fuzzied_src)
for idx, (example, _, _) in enumerate(batch):
if fuzzied_src[idx] != "":
example["src"] = fuzzied_src[idx].split()
example["src"] = fuzzied_src[idx].split(" ")

return batch
14 changes: 7 additions & 7 deletions onmt/transforms/inlinetags.py
@@ -73,8 +73,8 @@ def _tagged_src_tgt(self, src_example, tgt_example) -> tuple:
maybe_augmented[1].strip() if len(maybe_augmented) > 1 else None
)

tokenized_source_string = source_only.split()
tokenized_target_string = tgt_example.split()
tokenized_source_string = source_only.split(" ")
tokenized_target_string = tgt_example.split(" ")

src_offset, tgt_offset = 0, 0
src_with_tags, tgt_with_tags = list(), list()
@@ -140,12 +140,12 @@ def _tagged_src_tgt(self, src_example, tgt_example) -> tuple:

src_term = " ".join(
tokenized_source_string[
source_index : source_index + len(pair[0].split())
source_index : source_index + len(pair[0].split(" "))
]
)
tgt_term = " ".join(
tokenized_target_string[
target_index : target_index + len(pair[1].split())
target_index : target_index + len(pair[1].split(" "))
]
)

@@ -210,11 +210,11 @@ def _tagged_src_tgt(self, src_example, tgt_example) -> tuple:
tgt_with_tags.append(tgt_example[tgt_offset:])

return (
"".join(src_with_tags).replace("∥", " ").split(),
"".join(tgt_with_tags).replace("∥", " ").split(),
"".join(src_with_tags).replace("∥", " ").split(" "),
"".join(tgt_with_tags).replace("∥", " ").split(" "),
), is_match
else:
return (src_example.split(), tgt_example.split()), is_match
return (src_example.split(" "), tgt_example.split(" ")), is_match


@register_transform(name="inlinetags")
16 changes: 8 additions & 8 deletions onmt/transforms/misc.py
@@ -136,8 +136,8 @@ def get_specials(cls, opts):
prefix_dict = cls.get_prefix_dict(opts)
src_specials, tgt_specials = set(), set()
for _, prefix in prefix_dict.items():
src_specials.update(prefix["src"].split())
tgt_specials.update(prefix["tgt"].split())
src_specials.update(prefix["src"].split(" "))
tgt_specials.update(prefix["tgt"].split(" "))
return (src_specials, tgt_specials)

def warm_up(self, vocabs=None):
@@ -149,9 +149,9 @@ def _prepend(self, example, prefix):
"""Prepend `prefix` to `tokens`."""
for side, side_prefix in prefix.items():
if example.get(side) is not None:
example[side] = side_prefix.split() + example[side]
example[side] = side_prefix.split(" ") + example[side]
elif len(side_prefix) > 0:
example[side] = side_prefix.split()
example[side] = side_prefix.split(" ")
return example

def apply(self, example, is_train=False, stats=None, **kwargs):
@@ -250,8 +250,8 @@ def get_specials(cls, opts):
suffix_dict = cls.get_suffix_dict(opts)
src_specials, tgt_specials = set(), set()
for _, suffix in suffix_dict.items():
src_specials.update(suffix["src"].split())
tgt_specials.update(suffix["tgt"].split())
src_specials.update(suffix["src"].split(" "))
tgt_specials.update(suffix["tgt"].split(" "))
return (src_specials, tgt_specials)

def warm_up(self, vocabs=None):
@@ -263,9 +263,9 @@ def _append(self, example, suffix):
"""Prepend `suffix` to `tokens`."""
for side, side_suffix in suffix.items():
if example.get(side) is not None:
example[side] = example[side] + side_suffix.split()
example[side] = example[side] + side_suffix.split(" ")
elif len(side_suffix) > 0:
example[side] = side_suffix.split()
example[side] = side_suffix.split(" ")
return example

def apply(self, example, is_train=False, stats=None, **kwargs):
4 changes: 2 additions & 2 deletions onmt/transforms/normalize.py
@@ -329,7 +329,7 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
self.pre_dict[corpus_name],
self.post_dict[corpus_name],
)
example["src"] = src_str.split()
example["src"] = src_str.split(" ")

if example["tgt"] is not None:
tgt_str = self.mpn.normalize(
@@ -341,6 +341,6 @@
self.pre_dict[corpus_name],
self.post_dict[corpus_name],
)
example["tgt"] = tgt_str.split()
example["tgt"] = tgt_str.split(" ")

return example
16 changes: 8 additions & 8 deletions onmt/transforms/terminology.py
@@ -57,7 +57,7 @@ def _create_internal_termbase(self, termbase_path):
for pair in pairs:
src_term, tgt_term = map(str, pair.split("\t"))
src_lemma = " ".join(
"∥".join(tok.lemma_.split()) for tok in self.src_nlp(src_term)
"∥".join(tok.lemma_.split(" ")) for tok in self.src_nlp(src_term)
).strip()
tgt_lemma = " ".join(
tok.lemma_ for tok in self.tgt_nlp(tgt_term)
@@ -93,7 +93,7 @@ def _src_sentence_with_terms(self, source_string, target_string) -> tuple:

# Perform tokenization with spacy for consistency.
tokenized_source = [tok.text for tok in doc_src]
lemmatized_source = ["∥".join(tok.lemma_.lower().split()) for tok in doc_src]
lemmatized_source = ["∥".join(tok.lemma_.lower().split(" ")) for tok in doc_src]
lemmatized_target = [tok.lemma_.lower() for tok in doc_tgt]

lemmatized_source_string = " ".join(lemmatized_source)
@@ -143,7 +143,7 @@ def _src_sentence_with_terms(self, source_string, target_string) -> tuple:
lemma_list_index += len(w) + 1

# We need to know if the term is multiword
num_words_in_src_term = len(src_entry.split())
num_words_in_src_term = len(src_entry.split(" "))
src_term = " ".join(
tokenized_source[
lemma_list_index : lemma_list_index + num_words_in_src_term
@@ -164,7 +164,7 @@ def _src_sentence_with_terms(self, source_string, target_string) -> tuple:

if is_match:
source_with_terms.append(lemmatized_source_string[offset:])
tokenized_source_with_terms = "".join(source_with_terms).split()
tokenized_source_with_terms = "".join(source_with_terms).split(" ")

if not (
len(tokenized_source)
@@ -173,7 +173,7 @@ def _src_sentence_with_terms(self, source_string, target_string) -> tuple:
):
final_string = " ".join(tokenized_source)
fixed_punct = re.sub(r" ([^\w\s⦅\-\–])", r"\1", final_string)
return fixed_punct.split(), not is_match
return fixed_punct.split(" "), not is_match

# Construct the final source from the lemmatized list
# that contains the terms. We compare the tokens in the
@@ -195,17 +195,17 @@ def _src_sentence_with_terms(self, source_string, target_string) -> tuple:
final_string = " ".join(
completed_tokenized_source
+ [self.delimiter]
+ augmented_part.split()
+ augmented_part.split(" ")
)
else:
final_string = " ".join(completed_tokenized_source)

fixed_punct = re.sub(r" ([^\w\s⦅\-\–])", r"\1", final_string)
return fixed_punct.split(), is_match
return fixed_punct.split(" "), is_match
else:
final_string = " ".join(tokenized_source)
fixed_punct = re.sub(r" ([^\w\s⦅\-\–])", r"\1", final_string)
return fixed_punct.split(), not is_match
return fixed_punct.split(" "), not is_match


@register_transform(name="terminology")
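Both the inlinetags and terminology hunks lean on "∥" as a temporary stand-in for the spaces inside multi-word terms and lemmas, so that a later split(" ") keeps each term as a single item; the placeholder is turned back into a space afterwards. A small sketch with toy data (not taken from the diff):

term = "New York"
protected = "∥".join(term.split(" "))           # 'New∥York' survives space splitting as one token
tokens = ("I love " + protected + " a lot").split(" ")
print(tokens)                                   # ['I', 'love', 'New∥York', 'a', 'lot']
print([t.replace("∥", " ") for t in tokens])    # ['I', 'love', 'New York', 'a', 'lot']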
15 changes: 11 additions & 4 deletions onmt/transforms/tokenize.py
@@ -283,7 +283,7 @@ def apply_reverse(self, translated):
if isinstance(translated, list):
return self._detokenize(translated, "tgt")
else:
return self._detokenize(translated.split(), "tgt")
return self._detokenize(translated.split(" "), "tgt")

def _repr_args(self):
"""Return str represent key arguments for class."""
@@ -353,7 +353,7 @@ def apply_reverse(self, translated):
if isinstance(translated, list):
return self._detokenize(translated, "tgt")
else:
return self._detokenize(translated.split(), "tgt")
return self._detokenize(translated.split(" "), "tgt")


@register_transform(name="onmt_tokenize")
@@ -550,7 +550,14 @@ def tokenize_string(self, sentence, side="src", is_train=False):
self.maptable[b]
for b in sentence.replace(DefaultTokens.SEP, "\n").encode("utf-8")
)
segmented = tokenizer(sentence)
segmented1 = tokenizer(sentence)
segmented = []
# ugly patch to make sure "\n\n" is split into two items
for s in segmented1:
if s == "ĊĊ":
segmented.extend(["Ċ", "Ċ"])
else:
segmented.append(s)
else:
segmented = tokenizer(sentence)
return segmented
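GPT-2-style byte-level BPE renders the newline byte as "Ċ" (and the space byte as "Ġ"), and a tokenizer with a merged double-newline entry can emit a single "ĊĊ" piece. The patch above unpacks that piece into two "Ċ" tokens so downstream code sees one token per newline. A toy illustration of the post-processing step (the token list is invented):

segmented1 = ["Hello", "Ġworld", "ĊĊ", "Goodbye"]
segmented = []
for s in segmented1:
    if s == "ĊĊ":
        segmented.extend(["Ċ", "Ċ"])    # split the merged double newline into two tokens
    else:
        segmented.append(s)
print(segmented)                        # ['Hello', 'Ġworld', 'Ċ', 'Ċ', 'Goodbye']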
@@ -572,7 +579,7 @@ def apply_reverse(self, translated):
if isinstance(translated, list):
return self._detokenize(translated, "tgt")
else:
return self._detokenize(translated.split(), "tgt")
return self._detokenize(translated.split(" "), "tgt")

def _repr_args(self):
"""Return str represent key arguments for class."""
4 changes: 2 additions & 2 deletions onmt/transforms/uppercase.py
@@ -47,7 +47,7 @@ def apply(self, example, is_train=False, stats=None, **kwargs):
for c in unicodedata.normalize("NFD", src_str.upper())
if unicodedata.category(c) != "Mn"
)
example["src"] = src_str.split()
example["src"] = src_str.split(" ")

if example["tgt"] is not None:
tgt_str = " ".join(example["tgt"])
@@ -56,6 +56,6 @@
for c in unicodedata.normalize("NFD", tgt_str.upper())
if unicodedata.category(c) != "Mn"
)
example["tgt"] = tgt_str.split()
example["tgt"] = tgt_str.split(" ")

return example
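For reference, the uppercase transform context above folds accents while uppercasing: NFD normalization separates each accented character into a base letter plus a combining mark (Unicode category "Mn"), and the mark is then dropped. A standalone sketch of that step, with a made-up input:

import unicodedata

def fold_upper(s):
    # "é" -> "É" under upper(), then NFD splits it into "E" + combining accent, which is dropped
    return "".join(
        c for c in unicodedata.normalize("NFD", s.upper())
        if unicodedata.category(c) != "Mn"
    )

print(fold_upper("déjà vu"))   # 'DEJA VU'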