-
Notifications
You must be signed in to change notification settings - Fork 5
/
export.py
66 lines (52 loc) · 1.53 KB
/
export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import argparse
import os
import subprocess

import paddle

from GeoTr import GeoTr
def export(args):
    """Export a trained GeoTr checkpoint to a static-graph or ONNX model.

    The checkpoint at ``args.model`` is loaded into a fresh ``GeoTr``, put in
    eval mode, and converted with ``paddle.jit.to_static`` using a fixed
    input spec of shape ``[1, 3, imgsz, imgsz]`` (float32). The static
    program is saved next to the checkpoint under the path prefix ``model``.
    For ``format == "onnx"`` the saved program is then converted by the
    ``paddle2onnx`` command-line tool (opset 11) into ``model.onnx`` in the
    same directory.

    Args:
        args: Namespace with attributes ``model`` (checkpoint path; the
            loaded object's ``"model"`` entry holds the state dict),
            ``imgsz`` (int, side length of the square input) and ``format``
            (``"static"`` or ``"onnx"``).

    Raises:
        ValueError: if ``args.format`` is neither ``"static"`` nor ``"onnx"``.
        FileNotFoundError: if the ``paddle2onnx`` executable cannot be found.
        subprocess.CalledProcessError: if ``paddle2onnx`` exits non-zero.
    """
    model_path = args.model
    imgsz = args.imgsz
    export_format = args.format  # renamed locally to avoid shadowing format()

    # Validate before doing any expensive work: an unknown format used to load
    # the checkpoint and then silently export nothing.
    if export_format not in ("static", "onnx"):
        raise ValueError(
            f"Unsupported export format {export_format!r}; "
            "expected 'static' or 'onnx'"
        )

    model = GeoTr()
    checkpoint = paddle.load(model_path)
    model.set_state_dict(checkpoint["model"])
    model.eval()

    # A bare filename has an empty dirname; fall back to the current directory
    # so the paddle2onnx --model_dir argument is never empty.
    dirname = os.path.dirname(model_path) or "."

    # Both formats need the static-graph program; ONNX is converted from it.
    model = paddle.jit.to_static(
        model,
        input_spec=[
            paddle.static.InputSpec(shape=[1, 3, imgsz, imgsz], dtype="float32")
        ],
        full_graph=True,
    )
    paddle.jit.save(model, os.path.join(dirname, "model"))

    if export_format == "onnx":
        onnx_path = os.path.join(dirname, "model.onnx")
        # Pass an argument list (no shell) so paths with spaces or shell
        # metacharacters survive intact. check=True makes a failed conversion
        # raise instead of being ignored, as the os.system return code was.
        # The file names below assume paddle.jit.save wrote
        # model.pdmodel / model.pdiparams for the "model" prefix.
        subprocess.run(
            [
                "paddle2onnx",
                "--model_dir", dirname,
                "--model_filename", "model.pdmodel",
                "--params_filename", "model.pdiparams",
                "--save_file", onnx_path,
                "--opset_version", "11",
            ],
            check=True,
        )
def build_parser():
    """Build the command-line parser for the export script.

    Returns:
        An ``argparse.ArgumentParser`` accepting ``--model``/``-m`` (required
        checkpoint path), ``--imgsz`` (int, default 288) and ``--format``
        (``static`` or ``onnx``, default ``static``).
    """
    parser = argparse.ArgumentParser(description="export model")
    parser.add_argument(
        "--model",
        "-m",
        type=str,
        # Required: the old empty-string default only failed later, inside
        # paddle.load, with a far less helpful error.
        required=True,
        help="The path of model",
    )
    parser.add_argument(
        "--imgsz", type=int, default=288, help="The size of input image"
    )
    parser.add_argument(
        "--format",
        type=str,
        default="static",
        # Reject typos at parse time instead of silently exporting nothing.
        choices=["static", "onnx"],
        help="The format of exported model, which can be static or onnx",
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    export(args)