Skip to content

Commit

Permalink
First implementation of id_tokenization (HuggingFace)
Browse files Browse the repository at this point in the history
  • Loading branch information
francoishernandez committed Oct 2, 2024
1 parent 4a3d0dd commit 09dcaba
Show file tree
Hide file tree
Showing 11 changed files with 302 additions and 52 deletions.
53 changes: 41 additions & 12 deletions eole/bin/convert/convert_HF.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,16 @@
"XLMRobertaXLForMaskedLM": TransformerEncoderModelConfig,
}

# Architectures whose converted config should use the mistral tokenizer;
# every other supported architecture falls back to the generic
# huggingface_tokenize transform.
_MISTRAL_TOK_ARCHS = frozenset({"MistralForCausalLM", "MixtralForCausalLM"})

# HuggingFace architecture name -> eole tokenization transform name.
tok_table = {
    arch: ("mistral_tokenize" if arch in _MISTRAL_TOK_ARCHS else "huggingface_tokenize")
    for arch in (
        "LlamaForCausalLM",
        "MistralForCausalLM",
        "MixtralForCausalLM",
        "PhiForCausalLM",
        "Phi3ForCausalLM",
        "GPT2LMHeadModel",
        "XLMRobertaXLForMaskedLM",
    )
}


class Tokenizer:
def __init__(self, model_path: str):
Expand Down Expand Up @@ -306,6 +316,7 @@ def run(cls, args):
else:
generation_config_json = None
else:
huggingface_model = args.model_dir
directory_path = args.output
os.makedirs(directory_path, exist_ok=True)
try:
Expand Down Expand Up @@ -1053,6 +1064,33 @@ def get_weight(checkpoint, tensor_name):
for merge in data["model"]["merges"]:
bpemodel.write(merge + "\n")

transforms = [
tok_table[arch]
] # , "filtertoolong"] # the filtertoolong transform is not plug-n-play with id_tokenize
if tok_table[arch] == "huggingface_tokenize":
transforms_configs = {
tok_table[arch]: {"max_length": 512},
}
elif tok_table[arch] == "mistral_tokenize":
transforms_configs = {
tok_table[arch]: {
"path": os.path.join("${MODEL_PATH}", tokenizer_basename)
}
}
else:
# not used right now, but keeping for reference
transforms_configs = {
"filtertoolong": {"src_seq_length": 512, "tgt_seq_length": 512},
"onmt_tokenize": {
"src_subword_type": src_subword_type,
"src_subword_model": os.path.join(
"${MODEL_PATH}", tokenizer_basename
),
"gpt2_pretok": gpt2_pretok,
"mapped_tokens": mapped_tokens,
},
}

vocabs["src"] = src_vocab
vocabs["tgt"] = src_vocab
if add_bos_token:
Expand Down Expand Up @@ -1084,18 +1122,8 @@ def get_weight(checkpoint, tensor_name):
vocab_size_multiple=8,
decoder_start_token=vocabs["decoder_start_token"],
**vocabs["specials"],
transforms=["onmt_tokenize", "filtertoolong"],
transforms_configs={
"filtertoolong": {"src_seq_length": 512, "tgt_seq_length": 512},
"onmt_tokenize": {
"src_subword_type": src_subword_type,
"src_subword_model": os.path.join(
"${MODEL_PATH}", tokenizer_basename
),
"gpt2_pretok": gpt2_pretok,
"mapped_tokens": mapped_tokens,
},
},
transforms=transforms,
transforms_configs=transforms_configs,
model=arch_table[arch](
layers=n_layers,
hidden_size=hidden_size,
Expand All @@ -1122,6 +1150,7 @@ def get_weight(checkpoint, tensor_name):
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
left_pad=left_pad,
huggingface_model=huggingface_model,
),
training=TrainingConfig(
compute_dtype=compute_dtype,
Expand Down
13 changes: 12 additions & 1 deletion eole/config/data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from typing import Dict, List, Literal
from pydantic import Field, field_validator # , model_validator
from pydantic import Field, field_validator, model_validator
from pydantic import create_model

from eole import constants
Expand Down Expand Up @@ -332,3 +332,14 @@ def _validate_data_config(self, build_vocab_only=False):
# TrainConfig without existing files (e.g. inference)
# self._validate_vocab_config(build_vocab_only=build_vocab_only)
return self

@model_validator(mode="after")
def _maybe_set_huggingface_model(self):
if getattr(self, "model", None) is None:
return self
if self.model.huggingface_model is not None:
if hasattr(self.transforms_configs, "huggingface_tokenize"):
self.transforms_configs.huggingface_tokenize.huggingface_model = (
self.model.huggingface_model
)
return self
8 changes: 8 additions & 0 deletions eole/config/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
computed_field,
) # , TypeAdapter

import eole
from eole.constants import PositionEncodingType, ActivationFunction, ModelType
from eole.config.config import Config

Expand Down Expand Up @@ -438,6 +439,13 @@ class BaseModelConfig(Config):
left_pad: bool = Field(
default=False, description="Enable left-padding, useful for some LLMs."
)
huggingface_model: str | None = Field(
default=None, description="Original huggingface model."
)
eole_version: str | None = Field(
default=eole.__version__,
description="Eole version used to convert/train/save the model.",
)

# @computed_field()
# @property
Expand Down
8 changes: 6 additions & 2 deletions eole/inputters/dynamic_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,13 @@ def _tuple_to_json_with_tokIDs(self, tuple_bucket):
tuple_bucket = transform_bucket(self.task, tuple_bucket, self.score_threshold)
for example in tuple_bucket:
if example is not None:
bucket.append(
numericalize(self.vocabs, example, model_type=self.model_type)
numericalized = numericalize(
self.vocabs, example, model_type=self.model_type
)
bucket.append(numericalized)
# print(numericalized)
# exit()

return bucket

def _add_indice(self, bucket):
Expand Down
47 changes: 29 additions & 18 deletions eole/inputters/text_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,34 +59,45 @@ def transform_bucket(task, bucket, threshold=0):
def numericalize(vocabs, example, model_type=ModelType.ENCODER_DECODER):
""" """
decoder_start_token = vocabs["decoder_start_token"]
# print("decoder_start_token:", decoder_start_token)
# print(example)
numeric = example
numeric["src"]["src_ids"] = []
numeric["src"]["src_ids"] = example.get("src_ids", [])
maybe_tgt_ids = example.get("tgt_ids", [])
if model_type == ModelType.ENCODER_DECODER:
src_text = example["src"]["src"].split(" ")
numeric["src"]["src_ids"] = vocabs["src"](src_text)
if numeric["src"]["src_ids"] == []:
numeric["src"]["src_ids"] = vocabs["src"](src_text)
if example["tgt"] is not None:
numeric["tgt"]["tgt_ids"] = []
tgt_text = example["tgt"]["tgt"].split(" ")
numeric["tgt"]["tgt_ids"] = vocabs["tgt"](
[decoder_start_token]
+ tgt_text
+ [vocabs["specials"].get("eos_token", "")]
)
if maybe_tgt_ids != []:
numeric["tgt"]["tgt_ids"] = maybe_tgt_ids
else:
tgt_text = example["tgt"]["tgt"].split(" ")
numeric["tgt"]["tgt_ids"] = vocabs["tgt"](
[decoder_start_token]
+ tgt_text
+ [vocabs["specials"].get("eos_token", "")]
)

elif model_type == ModelType.DECODER:
src_text = example["src"]["src"].split(" ")
if decoder_start_token != "":
src_text = [decoder_start_token] + src_text
numeric["src"]["src_ids"] = vocabs["src"](src_text)
if numeric["src"]["src_ids"] == []:
src_text = example["src"]["src"].split(" ")
if decoder_start_token != "":
src_text = [decoder_start_token] + src_text
numeric["src"]["src_ids"] = vocabs["src"](src_text)
if example["tgt"] is not None:
numeric["tgt"]["tgt_ids"] = []
tgt_text = example["tgt"]["tgt"].split(" ")
numeric["tgt"]["tgt_ids"] = vocabs["tgt"](
tgt_text + [vocabs["specials"].get("eos_token", "")]
)
if maybe_tgt_ids != []:
# decoder_start_token logic is supposedly handled in the tokenizer
numeric["tgt"]["tgt_ids"] = maybe_tgt_ids
else:
tgt_text = example["tgt"]["tgt"].split(" ")
numeric["tgt"]["tgt_ids"] = vocabs["tgt"](
tgt_text + [vocabs["specials"].get("eos_token", "")]
)
if decoder_start_token == "":
numeric["tgt"]["tgt_ids"] = numeric["tgt"]["tgt_ids"][1:]

# TODO: support id tokenization
elif model_type == ModelType.ENCODER:
src_text = example["src"]["src"].split(" ")
if example["tgt"] is not None: # TO BE DISCUSSED
Expand Down
2 changes: 2 additions & 0 deletions eole/models/model_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def load_checkpoint(model_path):
if os.path.exists(config_path):
with open(config_path) as f:
config_dict = json.loads(os.path.expandvars(f.read()))
print(config_path)
print(config_dict)
# drop data to prevent validation issues
config_dict["data"] = {}
# drop inference to prevent validation issues
Expand Down
21 changes: 17 additions & 4 deletions eole/predict/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from math import exp
import codecs

from eole.transforms import TransformPipe
from eole.transforms import TransformPipe, AVAILABLE_TRANSFORMS
from eole.constants import DefaultTokens
from eole.predict.prediction import PredictionBuilder
from eole.utils.misc import set_random_seed, report_matrix, sequence_mask
Expand Down Expand Up @@ -89,6 +89,7 @@ def __init__(
return_gold_log_probs=False,
add_estimator=False,
optional_eos=[],
id_tokenization=False,
):
self.model = model
self.vocabs = vocabs
Expand Down Expand Up @@ -170,6 +171,7 @@ def __init__(

self.return_gold_log_probs = return_gold_log_probs
self.add_estimator = add_estimator
self.id_tokenization = id_tokenization

@classmethod
def from_config(
Expand Down Expand Up @@ -204,6 +206,12 @@ def from_config(
"""
# TODO: maybe add dynamic part

id_tokenization = False
if len(config.transforms) > 0:
tail_transform_cls = AVAILABLE_TRANSFORMS.get(config.transforms[-1], None)
if getattr(tail_transform_cls, "output_type", None) == "ids":
id_tokenization = True

return cls(
model,
vocabs,
Expand Down Expand Up @@ -238,6 +246,7 @@ def from_config(
with_score=config.with_score,
add_estimator=model_config.add_estimator,
optional_eos=config.optional_eos,
id_tokenization=id_tokenization,
)

def _log(self, msg):
Expand Down Expand Up @@ -296,6 +305,7 @@ def _predict(
self.replace_unk,
self.phrase_table,
self._tgt_eos_idx,
self.id_tokenization,
)

# Statistics
Expand Down Expand Up @@ -384,9 +394,12 @@ def _process_bucket(bucket_predictions):
bucket_gold_score += trans.gold_score
bucket_gold_words += len(trans.gold_sent) + 1

n_best_preds = [
" ".join(pred) for pred in trans.pred_sents[: self.n_best]
]
if self.id_tokenization:
n_best_preds = trans.pred_sents[: self.n_best]
else:
n_best_preds = [
" ".join(pred) for pred in trans.pred_sents[: self.n_best]
]

if self.report_align:
align_pharaohs = [
Expand Down
40 changes: 28 additions & 12 deletions eole/predict/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,20 @@ class PredictionBuilder(object):
"""

def __init__(
self, vocabs, n_best=1, replace_unk=False, phrase_table="", tgt_eos_idx=None
self,
vocabs,
n_best=1,
replace_unk=False,
phrase_table="",
tgt_eos_idx=None,
id_tokenization=False,
):
self.vocabs = vocabs
self.n_best = n_best
self.replace_unk = replace_unk
self.phrase_table_dict = {}
self.tgt_eos_idx = tgt_eos_idx # List of IDs here
self.id_tokenization = id_tokenization
if phrase_table != "" and os.path.exists(phrase_table):
with open(phrase_table) as phrase_table_fd:
for line in phrase_table_fd:
Expand All @@ -39,7 +46,10 @@ def _build_target_tokens(self, src, srclen, pred, attn, voc, dyn_voc):
pred_list = pred.tolist()
if pred_list[-1] in self.tgt_eos_idx:
pred_list = pred_list[:-1]
if dyn_voc is None:
if self.id_tokenization:
# TODO assert dyn_voc is not compatible with id_tokenization
tokens = pred_list
elif dyn_voc is None:
tokens = [voc[tok] for tok in pred_list]
else:
tokens = [
Expand All @@ -49,15 +59,19 @@ def _build_target_tokens(self, src, srclen, pred, attn, voc, dyn_voc):
for tok in pred_list
]

if self.replace_unk and attn is not None and src is not None:
for i in range(len(tokens)):
if tokens[i] == DefaultTokens.UNK:
_, max_index = attn[i][:srclen].max(0)
src_tok = self.vocabs["src"].ids_to_tokens[src[max_index.item()]]
tokens[i] = src_tok
if self.phrase_table_dict:
if src_tok in self.phrase_table_dict:
tokens[i] = self.phrase_table_dict[src_tok]
# TODO: either support this properly or remove?
if not self.id_tokenization:
if self.replace_unk and attn is not None and src is not None:
for i in range(len(tokens)):
if tokens[i] == DefaultTokens.UNK:
_, max_index = attn[i][:srclen].max(0)
src_tok = self.vocabs["src"].ids_to_tokens[
src[max_index.item()]
]
tokens[i] = src_tok
if self.phrase_table_dict:
if src_tok in self.phrase_table_dict:
tokens[i] = self.phrase_table_dict[src_tok]
return tokens

def from_batch(self, prediction_batch):
Expand Down Expand Up @@ -203,7 +217,9 @@ def log(self, sent_number, src_raw=""):
best_pred = self.pred_sents[0]
best_score = self.pred_scores[0]
best_estim = self.estim[0]
pred_sent = " ".join(best_pred)
pred_sent = " ".join(
[str(x) for x in best_pred]
) # this will display IDs for id_tokenize case
msg.append("PRED {}: {}\n".format(sent_number, pred_sent))
msg.append("PRED SCORE: {:.4f}\n".format(best_score))
msg.append("ESTIM SCORE: {:.4f}\n".format(best_estim))
Expand Down
4 changes: 2 additions & 2 deletions eole/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_transforms_cls(transform_names):
def register_transform(name):
"""Transform register that can be used to add new transform class."""

def register_transfrom_cls(cls):
def register_transform_cls(cls):
if name in AVAILABLE_TRANSFORMS:
raise ValueError("Cannot register duplicate transform ({})".format(name))
if not issubclass(cls, Transform):
Expand All @@ -47,7 +47,7 @@ def register_transfrom_cls(cls):
cls.name = name
return cls

return register_transfrom_cls
return register_transform_cls


# Auto import python files in this directory
Expand Down
Loading

0 comments on commit 09dcaba

Please sign in to comment.