Commit 0c0f60b: get rid of unnecessary decoder_start_table
francoishernandez committed Sep 20, 2024 · 1 parent f5378cc
Showing 1 changed file with 1 addition and 37 deletions.

eole/bin/convert/convert_HF.py (1 addition, 37 deletions)
@@ -190,16 +190,6 @@
     "XLMRobertaXLForMaskedLM": TransformerEncoderModelConfig,
 }
 
-decoder_start_table = {
-    "LlamaForCausalLM": "<s>",
-    "MistralForCausalLM": "<s>",
-    "MixtralForCausalLM": "<s>",
-    "PhiForCausalLM": "",
-    "Phi3ForCausalLM": "<s>",
-    "GPT2LMHeadModel": "</s>",
-    "XLMRobertaXLForMaskedLM": "<s>",
-}
-
 
 class Tokenizer:
     def __init__(self, model_path: str):
@@ -591,7 +581,6 @@ def run(cls, args):
             "n_positions": 0,
         }
         left_pad = True
-        # eos_token = None
         optional_eos = []
         mapped_tokens = []
 
@@ -951,11 +940,6 @@ def get_weight(checkpoint, tensor_name):
                 data["added_tokens_decoder"][str(index)]["content"]
                 for index in eos_token_id[1:]
             ]
-            # eos_token = optional_eos[0]
-        # elif isinstance(eos_token_id, int):
-        #     eos_token = data["added_tokens_decoder"][str(eos_token_id)][
-        #         "content"
-        #     ]
         # Automatically convert added_tokens into mapped_tokens
         mapped_tokens = [
             (
@@ -1008,16 +992,6 @@ def get_weight(checkpoint, tensor_name):
             vocab.extend(newtokens)
             for tok in data["added_tokens"]:
                 vocab[tok["id"]] = tok["content"]
-            # if "<|startoftext|>" in vocab:
-            #     index = vocab.index("<|startoftext|>")
-            #     vocab[index] = DefaultTokens.BOS
-            # if eos_token is not None:
-            #     if eos_token in vocab and "</s>" not in vocab:
-            #         index = vocab.index(eos_token)
-            #         vocab[index] = DefaultTokens.EOS
-            # if "<0x00>" in vocab:
-            #     index = vocab.index("<0x00>")
-            #     vocab[index] = DefaultTokens.PAD
             src_vocab = pyonmttok.build_vocab_from_tokens(
                 vocab,
             )
@@ -1042,16 +1016,6 @@ def get_weight(checkpoint, tensor_name):
                 vocab.append(DefaultTokens.VOCAB_PAD + str(i))
             for tok in data["added_tokens"]:
                 vocab[tok["id"]] = tok["content"]
-            # if "<|startoftext|>" in vocab:
-            #     index = vocab.index("<|startoftext|>")
-            #     vocab[index] = DefaultTokens.BOS
-            # if "<|begin_of_text|>" in vocab:
-            #     index = vocab.index("<|begin_of_text|>")
-            #     vocab[index] = DefaultTokens.BOS
-            # if eos_token is not None:
-            #     if eos_token in vocab and "</s>" not in vocab:
-            #         index = vocab.index(eos_token)
-            #         vocab[index] = DefaultTokens.EOS
             src_vocab = pyonmttok.build_vocab_from_tokens(vocab)
 
         tokenizer_basename = "bpe.model"
@@ -1066,7 +1030,7 @@ def get_weight(checkpoint, tensor_name):
         vocabs["src"] = src_vocab
         vocabs["tgt"] = src_vocab
         if add_bos_token:
-            vocabs["decoder_start_token"] = decoder_start_table[arch]
+            vocabs["decoder_start_token"] = vocabs["specials"]["bos_token"]
         else:
             vocabs["decoder_start_token"] = ""
         vocab_dict = vocabs_to_dict(vocabs)
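In effect, the converter now derives the decoder start token from the special tokens it has already collected, rather than consulting a hand-maintained table that had to grow with every new architecture. A minimal sketch of the before/after behavior follows; the field names mirror the diff, but the sample `arch`, `vocabs["specials"]` contents, and `add_bos_token` flag are hypothetical stand-ins for values the converter extracts from the Hugging Face config:

    # Sketch only: sample values are hypothetical stand-ins for what
    # run() extracts from the HF tokenizer/model config.
    arch = "LlamaForCausalLM"
    vocabs = {"specials": {"bos_token": "<s>", "eos_token": "</s>"}}
    add_bos_token = True

    # Before this commit: a hard-coded table keyed by architecture.
    decoder_start_table = {"LlamaForCausalLM": "<s>", "GPT2LMHeadModel": "</s>"}
    old_start = decoder_start_table[arch]  # KeyError for any unlisted arch

    # After this commit: reuse the BOS token already recorded in the
    # specials mapping, so new architectures need no table entry.
    new_start = vocabs["specials"]["bos_token"] if add_bos_token else ""
    assert old_start == new_start

The design upside is that the per-architecture table (and its KeyError failure mode for unlisted architectures) disappears; the BOS token recorded in `vocabs["specials"]` is the single source of truth.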