-
Notifications
You must be signed in to change notification settings - Fork 836
/
caffe_export.py
85 lines (68 loc) · 2.18 KB
/
caffe_export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# encoding: utf-8
"""
@author: xingyu liao
@contact: [email protected]
"""
import argparse
import logging
import sys
import torch
sys.path.append('.')
import pytorch_to_caffe
from fastreid.config import get_cfg
from fastreid.modeling.meta_arch import build_model
from fastreid.utils.file_io import PathManager
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils.logger import setup_logger
# import some modules added in project like this below
# sys.path.append("projects/PartialReID")
# from partialreid import *
# Configure the package-wide "fastreid" logger once, at import time. This
# script's own logger ("fastreid.caffe_export") is a child of it, so it
# presumably inherits whatever handlers setup_logger installs.
setup_logger(name="fastreid")
logger = logging.getLogger("fastreid.caffe_export")
def setup_cfg(args):
    """Build the frozen fastreid config for this export run.

    Args:
        args: parsed CLI namespace; ``args.config_file`` is the YAML to load and
            ``args.opts`` is a flat ``[KEY, VALUE, ...]`` override list.

    Returns:
        The merged config node, frozen.
    """
    config = get_cfg()
    # The file is merged first and the CLI pairs second, so --opts overrides
    # values coming from the YAML.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    return config
def get_parser():
    """Create the argument parser for the Caffe export script.

    Returns:
        argparse.ArgumentParser: accepts ``--config-file``, ``--name``
        (default ``"baseline"``), ``--output`` (default ``"caffe_model"``) and a
        trailing ``--opts`` list of config overrides (default ``[]``).
    """
    cli = argparse.ArgumentParser(description="Convert Pytorch to Caffe model")
    cli.add_argument("--config-file", metavar="FILE", help="path to config file")
    cli.add_argument("--name", default="baseline", help="name for converted model")
    cli.add_argument(
        "--output",
        default="caffe_model",
        help="path to save converted caffe model",
    )
    # REMAINDER collects every token after --opts verbatim, so arbitrary
    # KEY VALUE pairs reach cfg.merge_from_list without argparse parsing them.
    cli.add_argument(
        "--opts",
        default=[],
        nargs=argparse.REMAINDER,
        help="Modify config options using the command-line 'KEY VALUE' pairs",
    )
    return cli
def main():
    """Convert a trained fastreid model into Caffe ``.prototxt``/``.caffemodel`` files.

    Parses the CLI from :func:`get_parser`, builds the model described by
    ``--config-file`` (plus any ``--opts`` overrides) and loads
    ``cfg.MODEL.WEIGHTS``. It then traces one forward pass through
    ``pytorch_to_caffe`` and writes ``<output>/<name>.prototxt`` and
    ``<output>/<name>.caffemodel``.
    """
    parser = get_parser()
    args = parser.parse_args()
    # --config-file has no default; without this check setup_cfg would hand
    # None to merge_from_file and fail with an unhelpful traceback.
    if not args.config_file:
        parser.error("--config-file is required")

    cfg = setup_cfg(args)
    # setup_cfg returns a frozen config; unfreeze it to apply export-only overrides.
    cfg.defrost()
    # Skip backbone pretrain weights: the checkpoint loaded below supplies the weights.
    cfg.MODEL.BACKBONE.PRETRAIN = False
    # NOTE(review): pooling and non-local blocks are presumably disabled because
    # pytorch_to_caffe cannot translate them. Confirm this. With an Identity pool
    # layer the exported net likely emits un-pooled feature maps.
    cfg.MODEL.HEADS.POOL_LAYER = "Identity"
    cfg.MODEL.BACKBONE.WITH_NL = False

    model = build_model(cfg)
    if not cfg.MODEL.WEIGHTS:
        # NOTE(review): with no checkpoint path the model presumably keeps its
        # freshly initialised weights, so exporting it is almost certainly a mistake.
        logger.warning("cfg.MODEL.WEIGHTS is empty; exporting an untrained model")
    Checkpointer(model).load(cfg.MODEL.WEIGHTS)
    model.eval()
    logger.info(model)

    # A single random NCHW image at the test resolution is enough to trace the graph.
    height, width = cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1]
    inputs = torch.randn(1, 3, height, width).to(torch.device(cfg.MODEL.DEVICE))

    PathManager.mkdirs(args.output)
    pytorch_to_caffe.trans_net(model, inputs, args.name)
    pytorch_to_caffe.save_prototxt(f"{args.output}/{args.name}.prototxt")
    pytorch_to_caffe.save_caffemodel(f"{args.output}/{args.name}.caffemodel")
    logger.info("Export caffe model in %s successfully!", args.output)


if __name__ == "__main__":
    main()