convert.py
import argparse

import torch
from torch import Tensor
from safetensors.torch import save_file, load_file

# Command-line interface: input checkpoint, target precision, conversion type,
# and whether to write the result as safetensors instead of a pickled .ckpt.
parser = argparse.ArgumentParser()
parser.add_argument("-f", "--file", type=str,
                    default="model.ckpt", help="path to model")
parser.add_argument("-p", "--precision", type=str, default="fp32",
                    help="precision: fp32 (full) / fp16 / bf16")
parser.add_argument("-t", "--type", type=str, default="full",
                    help="conversion type: full / ema-only (prune) / no-ema")
parser.add_argument("-st", "--safe-tensors", action="store_true",
                    default=False, help="save in the safetensors format")
cmds = parser.parse_args()
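
# Illustrative invocations (sketches, not from the original source), assuming
# the script sits next to the checkpoint it converts:
#   python convert.py -f model.ckpt -p fp16 -t ema-only        # pruned fp16 .ckpt
#   python convert.py -f model.ckpt -p bf16 -t no-ema -st      # bf16 safetensors, EMA stripped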

# Per-tensor precision converters. Non-tensor entries in the state dict
# (e.g. stored ints) are passed through unchanged.
def conv_fp16(t: Tensor):
    if not isinstance(t, Tensor):
        return t
    return t.half()


def conv_bf16(t: Tensor):
    if not isinstance(t, Tensor):
        return t
    return t.bfloat16()


def conv_full(t):
    return t


# Maps the --precision argument to the converter applied to every tensor.
_g_precision_func = {
    "full": conv_full,
    "fp32": conv_full,
    "half": conv_fp16,
    "fp16": conv_fp16,
    "bf16": conv_bf16,
}
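
# Quick sanity sketch (illustrative, not part of the original script):
#   >>> conv_fp16(torch.zeros(2, dtype=torch.float32)).dtype
#   torch.float16
#   >>> conv_bf16("not a tensor")   # non-tensors pass through untouched
#   'not a tensor'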

def convert(path: str, conv_type: str):
    ok = {}  # output state dict
    _hf = _g_precision_func[cmds.precision]

    # Load either a safetensors file or a regular pickled checkpoint.
    if path.endswith(".safetensors"):
        m = load_file(path, device="cpu")
    else:
        m = torch.load(path, map_location="cpu")
    state_dict = m["state_dict"] if "state_dict" in m else m

    if conv_type in ("ema-only", "prune"):
        # Replace the live model weights with their EMA counterparts and drop
        # the separate "model_ema." entries (except the two bookkeeping keys).
        for k in state_dict:
            # EMA keys are "model_ema." + the key without its "model." prefix
            # and with the dots removed.
            ema_k = "model_ema." + k[6:].replace(".", "")
            if ema_k in state_dict:
                ok[k] = _hf(state_dict[ema_k])
                print("ema: " + ema_k + " > " + k)
            elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
                ok[k] = _hf(state_dict[k])
                print(k)
            else:
                print("skipped: " + k)
    elif conv_type == "no-ema":
        # Keep everything except the EMA weights.
        for k, v in state_dict.items():
            if "model_ema" not in k:
                ok[k] = _hf(v)
    else:
        # "full": keep every key; only the precision changes.
        for k, v in state_dict.items():
            ok[k] = _hf(v)
    return ok
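
# Example of the EMA key mapping above (hypothetical key name, for illustration):
#   "model.diffusion_model.input_blocks.0.0.weight"
#     -> ema_k = "model_ema.diffusion_modelinput_blocks00weight"
# i.e. the "model." prefix is swapped for "model_ema." and the remaining dots
# are removed, which matches how these checkpoints store their EMA copies.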

def main():
    model_name = ".".join(cmds.file.split(".")[:-1])  # input path without its extension
    converted = convert(cmds.file, cmds.type)
    save_name = f"{model_name}-{cmds.type}-{cmds.precision}"
    print("convert ok, saving model")
    if cmds.safe_tensors:
        # safetensors stores a flat tensor dict, so no "state_dict" wrapper.
        save_file(converted, save_name + ".safetensors")
    else:
        torch.save({"state_dict": converted}, save_name + ".ckpt")
    print("convert finished.")


if __name__ == "__main__":
    main()