deduce share_decoder_embeddings from HF tie_word_embeddings flag (#123)
francoishernandez authored Oct 3, 2024
1 parent 43593eb commit cdab121
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions eole/bin/convert/convert_HF.py
@@ -577,6 +577,8 @@ def run(cls, args):
         quant_layers = []
         params = ["weight", "bias"]
 
+        share_decoder_embeddings = config.get("tie_word_embeddings", False)
+
         add_qkvbias = False
         add_ffnbias = False
         shared_layer_norm = False
@@ -589,7 +591,6 @@ def run(cls, args):
         optional_eos = []
         mapped_tokens = []
         gpt2_pretok = False
-        share_decoder_embeddings = False
         generator_bias = False
 
         # ALL THESE IF SHOULD BE HANDLED IN MAPPINGS
@@ -689,6 +690,8 @@ def get_weight(checkpoint, tensor_name):
                     "encoder.layer_norm.bias",
                     "generator.weight",
                 ]
+                if share_decoder_embeddings:
+                    targetlist.remove("generator.weight")
                 for target in targetlist:
                     if target in key_maps[arch].keys():
                         source = key_maps[arch][target]
@@ -701,19 +704,10 @@ def get_weight(checkpoint, tensor_name):
                         w = get_weight(checkpoint, source)
                         if w is not None:
                             eole_safetensor[target] = w
-                        elif target == "generator.weight":
-                            # lm_head is not in HF safetensors -> share from embeddings matrix
-                            share_decoder_embeddings = True
 
                         if target == "generator.bias":
                             generator_bias = True
 
-            if torch.equal(
-                eole_safetensor.get("generator.weight", None),
-                eole_safetensor["tgt_emb.embeddings.weight"],
-            ):
-                share_decoder_embeddings = True
-
             if wmap_path:
                 weightmap = wmap["weight_map"]
                 ckpt_list = []
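For reference, the flag this commit reads comes straight from the model's Hugging Face config.json. Below is a minimal sketch of the lookup, assuming a config.json downloaded from the hub next to the safetensors checkpoint; the file path and variable names are illustrative, not part of this commit:

    import json

    # Load the Hugging Face config shipped alongside the safetensors checkpoint.
    with open("config.json", encoding="utf-8") as f:
        hf_config = json.load(f)

    # True when lm_head reuses the input embedding matrix (no separate
    # generator.weight is stored); defaults to False when the key is absent.
    share_decoder_embeddings = hf_config.get("tie_word_embeddings", False)
    print(share_decoder_embeddings)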
