Skip to content

Commit

Permalink
black
Browse files: browse the repository at this point in the history
  • Loading branch information
vince62s committed Jul 7, 2023
1 parent a8f0832 commit 8a44ab8
Showing 1 changed file with 21 additions and 12 deletions.
33 changes: 21 additions & 12 deletions tools/convert_T5.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,20 +107,26 @@ def __init__(self, model_path: str):
):
onmt_safetensor[
"encoder.transformer." + str(i) + ".self_attn.linear_query.weight"
] = (checkpoint[
"encoder.block." + str(i) + ".layer.0.SelfAttention.q.weight"
] / (dimperhead ** -0.5)).to(
] = (
checkpoint[
"encoder.block." + str(i) + ".layer.0.SelfAttention.q.weight"
]
/ (dimperhead**-0.5)
).to(
torch.float16
)
)
onmt_safetensor[
"decoder.transformer_layers."
+ str(i)
+ ".self_attn.linear_query.weight"
] = (checkpoint[
"decoder.block." + str(i) + ".layer.0.SelfAttention.q.weight"
] / (dimperhead ** -0.5)).to(
] = (
checkpoint[
"decoder.block." + str(i) + ".layer.0.SelfAttention.q.weight"
]
/ (dimperhead**-0.5)
).to(
torch.float16
)
)
onmt_safetensor[
"encoder.transformer." + str(i) + ".self_attn.linear_keys.weight"
] = checkpoint[
Expand Down Expand Up @@ -226,11 +232,14 @@ def __init__(self, model_path: str):
"decoder.transformer_layers."
+ str(i)
+ ".context_attn.linear_query.weight"
] = (checkpoint[
"decoder.block." + str(i) + ".layer.1.EncDecAttention.q.weight"
] / (dimperhead ** -0.5)).to(
] = (
checkpoint[
"decoder.block." + str(i) + ".layer.1.EncDecAttention.q.weight"
]
/ (dimperhead**-0.5)
).to(
torch.float16
)
)
onmt_safetensor[
"decoder.transformer_layers."
+ str(i)
Expand Down

0 comments on commit 8a44ab8

Please sign in to comment.