Rework handling of special tokens (#45)
francoishernandez authored Sep 20, 2024
1 parent 0bb7e80 commit b849e18
Showing 25 changed files with 590 additions and 138 deletions.
50 changes: 34 additions & 16 deletions docs/docusaurus_tsx/docs/FAQ/special_tokens.md
@@ -1,21 +1,39 @@
# What special tokens are used?

In v2, special tokens were different for SEQ2SEQ and LM:
LM used BOS, PAD, EOS with IDs (0, 1, 2), and the first vocab token started at id=3.
SEQ2SEQ used UNK, PAD, BOS, EOS with IDs (0, 1, 2, 3), and the first vocab token started at id=4.

In v3 we changed this behavior to align things:
group.add(
"--default_specials",
"-default_specilas",
nargs="+",
type=str,
default=[
DefaultTokens.UNK,
DefaultTokens.PAD,
DefaultTokens.BOS,
DefaultTokens.EOS,
])
There are 4 main special tokens:
- BOS for "beginning of sentence";
- PAD for "padding";
- EOS for "end of sentence";
- UNK for "unknown".

## Special tokens actually used

Depending on the context, these tokens can take various values:

1. Default behaviour, training from scratch

Some default values are defined as [constants](https://github.com/eole-nlp/eole/blob/ff39275c50d12951963008da11d029940b590713/eole/constants.py#L8) for the project:
```python
class DefaultTokens(object):
PAD = "<blank>"
BOS = "<s>"
EOS = "</s>"
UNK = "<unk>"
```
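
These constants also back the per-token configuration defaults. A minimal sketch, assuming `eole` is installed and importable (the import path follows `eole/constants.py` as linked above):

```python
from eole.constants import DefaultTokens

# Values used when no per-token override is given in the configuration.
print(DefaultTokens.BOS)  # "<s>"
print(DefaultTokens.PAD)  # "<blank>"
print(DefaultTokens.EOS)  # "</s>"
print(DefaultTokens.UNK)  # "<unk>"
```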

2. Retrieving a pretrained model from HF

The special tokens will be retrieved and configured from the `special_tokens_map.json` file shipped with the HF model files.
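
For illustration, the extraction can be sketched roughly as follows (a simplified version of the handling added to `convert_HF.py` further down in this commit; entries in the map may be plain strings, lists, or dicts carrying a `content` field):

```python
import json


def read_special_tokens(path):
    """Extract bos/unk/eos/pad entries from a HF special_tokens_map.json."""
    with open(path, encoding="utf-8") as f:
        special_tokens_map = json.load(f)
    specials = {}
    for token_name in ["bos_token", "unk_token", "eos_token", "pad_token"]:
        token = special_tokens_map.get(token_name, None)
        if isinstance(token, list):    # some models list several candidates
            specials[token_name] = token[0]
        elif isinstance(token, str):   # plain string form
            specials[token_name] = token
        elif isinstance(token, dict):  # Llama2-style {"content": "<s>", ...}
            specials[token_name] = token["content"]
    return specials
```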

3. Custom behaviour

In any case, these tokens can be overridden via the ad-hoc configuration settings (a sketch follows the list below):
- `bos_token`
- `pad_token`
- `eos_token`
- `unk_token`
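
For instance, a hypothetical override for a model using non-default markers might look like this (the token values below are purely illustrative; only the keys correspond to the actual configuration fields):

```python
# Hypothetical values; only the keys are actual configuration fields.
special_token_overrides = {
    "bos_token": "<|startoftext|>",
    "eos_token": "<|endoftext|>",
    "pad_token": "<pad>",
    "unk_token": "<unk>",
}
```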

## Special tokens behaviour in Eole

When we train a SEQ2SEQ model we use:
SRC: srctok1 srctok2 srctok3 .... srctokn
111 changes: 53 additions & 58 deletions eole/bin/convert/convert_HF.py
@@ -190,26 +190,6 @@
"XLMRobertaXLForMaskedLM": TransformerEncoderModelConfig,
}

decoder_start_table = {
"LlamaForCausalLM": "<s>",
"MistralForCausalLM": "<s>",
"MixtralForCausalLM": "<s>",
"PhiForCausalLM": "",
"Phi3ForCausalLM": "<s>",
"GPT2LMHeadModel": "</s>",
"XLMRobertaXLForMaskedLM": "<s>",
}

specials_table = {
"LlamaForCausalLM": ["<unk>", "<s>", "</s>"],
"MistralForCausalLM": ["<unk>", "<s>", "</s>"],
"MixtralForCausalLM": ["<unk>", "<s>", "</s>"],
"PhiForCausalLM": ["<unk>", "<s>", "</s>"],
"Phi3ForCausalLM": ["<unk>", "<s>", "</s>"],
"GPT2LMHeadModel": ["<unk>", "<s>", "</s>"],
"XLMRobertaXLForMaskedLM": ["<s>", "<blank>", "</s>", "<unk>"],
}


class Tokenizer:
def __init__(self, model_path: str):
@@ -313,6 +293,12 @@ def run(cls, args):
)
else:
tokenizer_config_json = None
if os.path.exists(os.path.join(args.model_dir, "special_tokens_map.json")):
tokenizer_config_json = os.path.join(
args.model_dir, "special_tokens_map.json"
)
else:
tokenizer_config_json = None
if os.path.exists(os.path.join(args.model_dir, "generation_config.json")):
generation_config_json = os.path.join(
args.model_dir, "generation_config.json"
@@ -415,6 +401,22 @@ def run(cls, args):
raise huggingface_hub.utils.EntryNotFoundError(
"No valid model files found"
)
try:
try:
special_tokens_json = huggingface_hub.hf_hub_download(
repo_id=args.model_dir,
filename="special_tokens_map.json",
token=args.token,
)
except huggingface_hub.utils.EntryNotFoundError:
raise huggingface_hub.utils.EntryNotFoundError(
"Something went wrong the repo does not contain"
"any special_tokens_map.json file"
)
except Exception as e:
if isinstance(e, huggingface_hub.utils.EntryNotFoundError):
special_tokens_json = None
print(e)

with open(config_path, encoding="utf-8") as fconfig:
config = json.load(fconfig)
@@ -584,7 +586,6 @@ def run(cls, args):
"n_positions": 0,
}
left_pad = True
eos_token = None
optional_eos = []
mapped_tokens = []
gpt2_pretok = False
@@ -947,18 +948,6 @@ def get_weight(checkpoint, tensor_name):
for index in eos_token_id
]
optional_eos = eos_tokens[1:]
eos_token = eos_tokens[0]
elif isinstance(eos_token_id, int):
if "eos_token" in data.keys():
if isinstance(data["eos_token"], dict):
# Llama2 style
eos_token = data["eos_token"]["content"]
elif isinstance(data["eos_token"], str):
eos_token = data["eos_token"]
elif "added_tokens_decoder" in data.keys():
eos_token = data["added_tokens_decoder"][str(eos_token_id)][
"content"
]
# Automatically convert added_tokens into mapped_tokens
if "added_tokens_decoder" in data.keys():
mapped_tokens = [
@@ -973,6 +962,29 @@
else:
add_bos_token = True

vocabs = {"specials": {}}

if special_tokens_json is not None:
with open(special_tokens_json, encoding="utf-8") as f:
special_tokens_map = json.load(f)
for token_name in ["bos_token", "unk_token", "eos_token", "pad_token"]:
token = special_tokens_map.get(token_name, None)
if isinstance(token, list):
vocabs["specials"][token_name] = token[0]
elif isinstance(token, str):
vocabs["specials"][token_name] = token
elif isinstance(token, dict):
vocabs["specials"][token_name] = token["content"]
elif tokenizer_json is not None:
with open(tokenizer_json, encoding="utf-8") as f:
data = json.load(f)
vocab = {v: k for k, v in data["model"]["vocab"].items()}
for token_name in ["bos_token", "unk_token", "eos_token", "pad_token"]:
if f"{token_name}_id" in config.keys():
vocabs["specials"][token_name] = vocab[
config[f"{token_name}_id"]
]

if generation_config_json is not None:
with open(generation_config_json, encoding="utf-8") as f:
data = json.load(f)
@@ -982,8 +994,6 @@ def get_weight(checkpoint, tensor_name):
for key in keys:
if key in data.keys():
generation_config_dict[key] = data[key]

vocabs = {}
if (
tokenizer_model is not None
): # sentencepiece mode (might be good to check it's a SP model)
@@ -1003,26 +1013,19 @@ def get_weight(checkpoint, tensor_name):
vocab.extend(newtokens)
for tok in data["added_tokens"]:
vocab[tok["id"]] = tok["content"]
if "<|startoftext|>" in vocab:
index = vocab.index("<|startoftext|>")
vocab[index] = DefaultTokens.BOS
if eos_token is not None:
if eos_token in vocab and "</s>" not in vocab:
index = vocab.index(eos_token)
vocab[index] = DefaultTokens.EOS
if "<0x00>" in vocab:
index = vocab.index("<0x00>")
vocab[index] = DefaultTokens.PAD
src_vocab = pyonmttok.build_vocab_from_tokens(
vocab,
special_tokens=specials_table[arch],
)
else: # # BPE mode - we leverage the HF tokenizer.json info
src_subword_type = "bpe"
with open(tokenizer_json, encoding="utf-8") as f:
data = json.load(f)
# gpt2_pretok
pretokenizers = data.get("pre_tokenizer", {}).get("pretokenizers", [{}])
pre_tokenizer = data.get("pre_tokenizer", None)
pretokenizers = pre_tokenizer.get("pretokenizers", None)
if pretokenizers is None:
pretokenizers = [pre_tokenizer]
for pretokenizer in pretokenizers:
if pretokenizer.get("type", None) == "ByteLevel":
gpt2_pretok = True
@@ -1031,23 +1034,14 @@ def get_weight(checkpoint, tensor_name):
# "Ā" is '\x00' in unicode (cf tokenize.py gpt2 mapping)
for tok in data["model"]["vocab"]
]
if DefaultTokens.PAD in vocab:
vocabs["specials"]["pad_token"] = DefaultTokens.PAD
voc_size = len(vocab)
if vocab_size > voc_size:
for i in range(vocab_size - voc_size):
vocab.append(DefaultTokens.VOCAB_PAD + str(i))
for tok in data["added_tokens"]:
vocab[tok["id"]] = tok["content"]
if "<|startoftext|>" in vocab:
index = vocab.index("<|startoftext|>")
vocab[index] = DefaultTokens.BOS
if "<|begin_of_text|>" in vocab:
index = vocab.index("<|begin_of_text|>")
vocab[index] = DefaultTokens.BOS
if eos_token is not None:
if eos_token in vocab and "</s>" not in vocab:
index = vocab.index(eos_token)
vocab[index] = DefaultTokens.EOS

src_vocab = pyonmttok.build_vocab_from_tokens(vocab)

tokenizer_basename = "bpe.model"
@@ -1062,7 +1056,7 @@ def get_weight(checkpoint, tensor_name):
vocabs["src"] = src_vocab
vocabs["tgt"] = src_vocab
if add_bos_token:
vocabs["decoder_start_token"] = decoder_start_table[arch]
vocabs["decoder_start_token"] = vocabs["specials"]["bos_token"]
else:
vocabs["decoder_start_token"] = ""
vocab_dict = vocabs_to_dict(vocabs)
@@ -1089,6 +1083,7 @@ def get_weight(checkpoint, tensor_name):
tgt_vocab_size=vocab_size,
vocab_size_multiple=8,
decoder_start_token=vocabs["decoder_start_token"],
**vocabs["specials"],
transforms=["onmt_tokenize", "filtertoolong"],
transforms_configs={
"filtertoolong": {"src_seq_length": 512, "tgt_seq_length": 512},
3 changes: 2 additions & 1 deletion eole/bin/tools/LM_scoring.py
@@ -82,7 +82,8 @@ def run(cls, args):
)

vocabs, model, model_opt = config.model.model_class.load_test_model(config)
padding_idx = vocabs["tgt"][DefaultTokens.PAD]
pad_token = vocabs["specials"].get("pad_token", DefaultTokens.PAD)
padding_idx = vocabs["tgt"][pad_token]
criterion = torch.nn.CrossEntropyLoss(
ignore_index=padding_idx, reduction="none"
)
20 changes: 11 additions & 9 deletions eole/config/data.py
@@ -26,15 +26,17 @@ class BaseVocabConfig(Config):
description="Default decoder start token. For most models it is <s> = BOS. "
"Some fairseq models require </s>.",
)
default_specials: list = Field(
default=[
constants.DefaultTokens.UNK,
constants.DefaultTokens.PAD,
constants.DefaultTokens.BOS,
constants.DefaultTokens.EOS,
],
description="Default specials used for vocab initialization. "
"UNK, PAD, BOS, EOS will take IDs 0, 1, 2, 3.",
bos_token: str | None = Field(
default=constants.DefaultTokens.BOS,
)
eos_token: str | None = Field(
default=constants.DefaultTokens.EOS,
)
unk_token: str | None = Field(
default=constants.DefaultTokens.UNK,
)
pad_token: str | None = Field(
default=constants.DefaultTokens.PAD,
)
# pre trained embeddings stuff, might be put elsewhere
both_embeddings: str | None = Field(
12 changes: 11 additions & 1 deletion eole/inference_engine.py
@@ -278,6 +278,13 @@ def __init__(self, config, model_type=None):
vocabs["src"] = src_vocab
vocabs["tgt"] = src_vocab
vocabs["decoder_start_token"] = "<s>"
# TODO: this should be loaded from model config
vocabs["specials"] = {
"bos_token": DefaultTokens.BOS,
"pad_token": DefaultTokens.PAD,
"eos_token": DefaultTokens.EOS,
"unk_token": DefaultTokens.UNK,
}
self.vocabs = vocabs
# Build transform pipe
transforms = make_transforms(config, self.transforms_cls, self.vocabs)
@@ -290,7 +297,10 @@ def predict_batch(self, batch, config):
_input_tokens = [
self.vocabs["src"].lookup_index(id)
for id in start_ids
if id != self.vocabs["src"].lookup_token(DefaultTokens.PAD)
if id
!= self.vocabs["src"].lookup_token(
self.vocabs["specials"].get("pad_token", DefaultTokens.PAD)
)
]
input_tokens.append(_input_tokens)
if self.model_type == ModelType.DECODER:
18 changes: 14 additions & 4 deletions eole/inputters/inputter.py
@@ -16,6 +16,13 @@ def build_vocab(config, specials):
'decoder_start_token': DefaultTokens.BOS
}
"""
vocabs = {}
vocabs["specials"] = {
"bos_token": config.bos_token,
"pad_token": config.pad_token,
"eos_token": config.eos_token,
"unk_token": config.unk_token,
}

def _pad_vocab_to_multiple(vocab, multiple):
vocab_size = len(vocab)
@@ -26,8 +33,8 @@ def _pad_vocab_to_multiple(vocab, multiple):
vocab.add_token(DefaultTokens.VOCAB_PAD + str(i))
return vocab

default_specials = config.default_specials
vocabs = {}
default_specials = list(vocabs["specials"].values())

src_vocab = _read_vocab_file(config.src_vocab, config.src_words_min_frequency)

src_specials = [
@@ -45,7 +52,7 @@ def _pad_vocab_to_multiple(vocab, multiple):
src_vocab = pyonmttok.build_vocab_from_tokens(
src_vocab, maximum_size=config.src_vocab_size, special_tokens=src_specials
)
src_vocab.default_id = src_vocab[DefaultTokens.UNK]
src_vocab.default_id = src_vocab[config.unk_token]
if src_vocab.default_id >= len(src_vocab):
src_vocab.default_id = (
0 # patch that assigns OOV to id=0 when UNK does not exist
@@ -80,7 +87,6 @@ def _pad_vocab_to_multiple(vocab, multiple):
vocabs["tgt"] = tgt_vocab

vocabs["decoder_start_token"] = config.decoder_start_token

return vocabs


@@ -126,6 +132,8 @@ def vocabs_to_dict(vocabs):
vocabs_dict["decoder_start_token"] = vocabs["decoder_start_token"]
else:
vocabs_dict["decoder_start_token"] = DefaultTokens.BOS
if "specials" in vocabs.keys():
vocabs_dict["specials"] = vocabs["specials"]
return vocabs_dict


Expand All @@ -148,5 +156,7 @@ def dict_to_vocabs(vocabs_dict):
vocabs["tgt"] = pyonmttok.build_vocab_from_tokens(vocabs_dict["tgt"])
if vocabs["tgt"].default_id >= len(vocabs["src"]):
vocabs["tgt"].default_id = 0 # patch that assigns OOV to id=0
if "specials" in vocabs_dict.keys():
vocabs["specials"] = vocabs_dict["specials"]

return vocabs
