
Commit

patch convert_HF for better back-compatibility, grab compute_dtype from lora weights

francoishernandez committed Sep 19, 2024
1 parent b5df05c commit 689fece
Showing 2 changed files with 32 additions and 15 deletions.
eole/bin/convert/convert_HF.py (24 additions, 15 deletions)

@@ -587,6 +587,7 @@ def run(cls, args):
         eos_token = None
         optional_eos = []
         mapped_tokens = []
+        gpt2_pretok = False
 
         # ALL THESE IF SHOULD BE HANDLED IN MAPPINGS
         if arch == "PhiForCausalLM":

@@ -940,23 +941,32 @@ def get_weight(checkpoint, tensor_name):
             # Not sure if we could do much cleaner to retrieve optional eos tokens
             eos_token_id = config.get("eos_token_id", None)
             if isinstance(eos_token_id, list):
-                optional_eos = [
-                    data["added_tokens_decoder"][str(index)]["content"]
-                    for index in eos_token_id[1:]
-                ]
-                eos_token = optional_eos[0]
+                if "added_tokens_decoder" in data.keys():
+                    optional_eos = [
+                        data["added_tokens_decoder"][str(index)]["content"]
+                        for index in eos_token_id[1:]
+                    ]
+                    eos_token = optional_eos[0]
             elif isinstance(eos_token_id, int):
-                eos_token = data["added_tokens_decoder"][str(eos_token_id)][
-                    "content"
-                ]
+                if "eos_token" in data.keys():
+                    if isinstance(data["eos_token"], dict):
+                        # Llama2 style
+                        eos_token = data["eos_token"]["content"]
+                    elif isinstance(data["eos_token"], str):
+                        eos_token = data["eos_token"]
+                elif "added_tokens_decoder" in data.keys():
+                    eos_token = data["added_tokens_decoder"][str(eos_token_id)][
+                        "content"
+                    ]
             # Automatically convert added_tokens into mapped_tokens
-            mapped_tokens = [
-                (
-                    token["content"],
-                    re.sub(r"<\|([^|]*)\|>", "\uff5f\\1\uff60", token["content"]),
-                )
-                for token in data["added_tokens_decoder"].values()
-            ]
+            if "added_tokens_decoder" in data.keys():
+                mapped_tokens = [
+                    (
+                        token["content"],
+                        re.sub(r"<\|([^|]*)\|>", "\uff5f\\1\uff60", token["content"]),
+                    )
+                    for token in data["added_tokens_decoder"].values()
+                ]
         else:
             add_bos_token = True
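
For context on the back-compatibility part of this commit: the patched branch now tolerates tokenizer_config.json files that have no "added_tokens_decoder" entry, and it accepts both the Llama2-style dict form and the plain-string form of "eos_token". Below is a minimal standalone sketch of that lookup order; the helper name resolve_eos and the sample configs are illustrative, not part of the converter.

    def resolve_eos(data, eos_token_id):
        # Sketch of the fallback order used in the patched converter.
        eos_token, optional_eos = None, []
        if isinstance(eos_token_id, list):
            if "added_tokens_decoder" in data:
                optional_eos = [
                    data["added_tokens_decoder"][str(index)]["content"]
                    for index in eos_token_id[1:]
                ]
                eos_token = optional_eos[0]
        elif isinstance(eos_token_id, int):
            if "eos_token" in data:
                eos = data["eos_token"]
                # Llama2-style dict vs. plain string
                eos_token = eos["content"] if isinstance(eos, dict) else eos
            elif "added_tokens_decoder" in data:
                eos_token = data["added_tokens_decoder"][str(eos_token_id)]["content"]
        return eos_token, optional_eos

    # Hypothetical tokenizer_config.json contents:
    print(resolve_eos({"eos_token": {"content": "</s>"}}, 2))  # ('</s>', [])
    print(resolve_eos({"eos_token": "<|im_end|>"}, 2))         # ('<|im_end|>', [])
    print(resolve_eos({}, 2))                                  # (None, [])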

Expand Down Expand Up @@ -1009,7 +1019,6 @@ def get_weight(checkpoint, tensor_name):
with open(tokenizer_json, encoding="utf-8") as f:
data = json.load(f)
# gpt2_pretok
gpt2_pretok = False
pretokenizers = data.get("pre_tokenizer", {}).get("pretokenizers", [{}])
for pretokenizer in pretokenizers:
if pretokenizer.get("type", None) == "ByteLevel":
Expand Down
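
The gpt2_pretok change is a hoist: the flag is now initialized near the top of run() (first hunk) rather than inside the tokenizer.json branch (this hunk), so it is defined even when no tokenizer.json is read. A rough standalone sketch of that pattern, with an illustrative path check that is not in the converter itself:

    import json
    import os

    # Default the flag up front so later code can always read it.
    gpt2_pretok = False

    tokenizer_json = "tokenizer.json"  # illustrative path
    if os.path.exists(tokenizer_json):
        with open(tokenizer_json, encoding="utf-8") as f:
            data = json.load(f)
        # Detect a GPT-2 style ByteLevel pre-tokenizer, as the converter does.
        pretokenizers = data.get("pre_tokenizer", {}).get("pretokenizers", [{}])
        for pretokenizer in pretokenizers:
            if pretokenizer.get("type", None) == "ByteLevel":
                gpt2_pretok = True

    print("gpt2_pretok:", gpt2_pretok)
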
eole/bin/model/lora_weights.py (8 additions, 0 deletions)

@@ -52,6 +52,10 @@ def add_args(cls, parser):
     @classmethod
     def run(cls, args):
         init_logger()
+        config_path = os.path.join(args.base_model, "config.json")
+        with open(config_path) as f:
+            config = json.load(f)
+        inference_config = config.get("inference", None)
         base_checkpoint = load_checkpoint(args.base_model)
         lora_checkpoint = load_checkpoint(args.lora_weights)
         vocabs = dict_to_vocabs(lora_checkpoint["vocab"])

@@ -84,6 +88,8 @@ def run(cls, args):
             optim = None
             model_state_dict = model.state_dict()
             new_config = base_checkpoint["config"]
+            # use compute_dtype from lora finetuning
+            new_config.training.compute_dtype = config.training.compute_dtype
         elif args.action == "concat":
             model.half()  # We keep FP16 for all
             optim = lora_checkpoint["optim"]

@@ -101,6 +107,8 @@ def run(cls, args):
             json.dump(vocab_dict, f, indent=2, ensure_ascii=False)
         # save config
         config_dict = recursive_model_fields_set(new_config)
+        if inference_config is not None:
+            config_dict["inference"] = inference_config
         with open(os.path.join(args.output, "config.json"), "w", encoding="utf-8") as f:
             json.dump(config_dict, f, indent=2, ensure_ascii=False)
         shards = glob.glob(os.path.join(args.base_model, "model.*.safetensors"))
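
Taken together, the three additions to lora_weights.py keep the base model's optional "inference" section and carry the finetuning run's compute_dtype into the merged output. A rough standalone sketch of that flow using plain dicts; the real tool operates on eole checkpoint and config objects, so merge_configs, its arguments, and the dict shapes below are illustrative only.

    import json
    import os

    def merge_configs(base_dir, lora_compute_dtype, output_dir):
        # Grab the optional "inference" section from the base model's config.json.
        with open(os.path.join(base_dir, "config.json")) as f:
            base_config = json.load(f)
        inference_config = base_config.get("inference", None)

        # In the real tool new_config comes from the base checkpoint and the
        # compute_dtype from the lora finetuning config; dicts stand in here.
        new_config = dict(base_config)
        new_config.setdefault("training", {})["compute_dtype"] = lora_compute_dtype

        # Re-attach the inference section before saving, as the patch does.
        if inference_config is not None:
            new_config["inference"] = inference_config

        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "config.json"), "w", encoding="utf-8") as f:
            json.dump(new_config, f, indent=2, ensure_ascii=False)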
