Skip to content

Commit 9ac244c

Browse files
authoredMay 17, 2020
Merge pull request #13 from vladkol/vladkol/onnx
Added ONNX export
2 parents c5aecdc + 30437f8 commit 9ac244c

File tree

3 files changed

+85
-0
lines changed

3 files changed

+85
-0
lines changed
 

‎tasks/task1/exportonnx.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#!/usr/bin/python3
2+
from __future__ import print_function
3+
4+
import os
5+
import sys
6+
import torch
7+
import torch.backends.cudnn as cudnn
8+
import argparse
9+
import cv2
10+
import numpy as np
11+
from collections import OrderedDict
12+
13+
sys.path.append(os.getcwd() + '/../../src')
14+
15+
from config import cfg
16+
from prior_box import PriorBox
17+
from nms import nms
18+
from utils import decode
19+
from timer import Timer
20+
from yufacedetectnet import YuFaceDetectNet
21+
22+
parser = argparse.ArgumentParser(description='Face and Landmark Detection')
23+
24+
parser.add_argument('-m', '--trained_model', default='weights/yunet_final.pth',
25+
type=str, help='Trained state_dict file path to open')
26+
parser.add_argument('-d', '--image_dim', default=320,
27+
type=int, help='Input image width')
28+
parser.add_argument('-o', '--output', default='onnx/facedetectcnn.onnx',
29+
type=str, help='The output ONNX file, trained parameters inside')
30+
args = parser.parse_args()
31+
32+
def check_keys(model, pretrained_state_dict):
33+
ckpt_keys = set(pretrained_state_dict.keys())
34+
model_keys = set(model.state_dict().keys())
35+
used_pretrained_keys = model_keys & ckpt_keys
36+
unused_pretrained_keys = ckpt_keys - model_keys
37+
missing_keys = model_keys - ckpt_keys
38+
print('Missing keys:{}'.format(len(missing_keys)))
39+
print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
40+
print('Used keys:{}'.format(len(used_pretrained_keys)))
41+
assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
42+
return True
43+
44+
def remove_prefix(state_dict, prefix):
45+
''' Old style model is stored with all names of parameters sharing common prefix 'module.' '''
46+
print('remove prefix \'{}\''.format(prefix))
47+
f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
48+
return {f(key): value for key, value in state_dict.items()}
49+
50+
def load_model(model, pretrained_path, load_to_cpu):
51+
print('Loading pretrained model from {}'.format(pretrained_path))
52+
if load_to_cpu:
53+
pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
54+
else:
55+
device = torch.cuda.current_device()
56+
pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
57+
if "state_dict" in pretrained_dict.keys():
58+
pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
59+
else:
60+
pretrained_dict = remove_prefix(pretrained_dict, 'module.')
61+
check_keys(model, pretrained_dict)
62+
model.load_state_dict(pretrained_dict, strict=False)
63+
return model
64+
65+
66+
if __name__ == '__main__':
67+
68+
torch.set_grad_enabled(False)
69+
70+
# net and model
71+
net = YuFaceDetectNet(phase='test', size=None ) # initialize detector
72+
net = load_model(net, args.trained_model, True)
73+
net.eval()
74+
75+
print('Finished loading model!')
76+
77+
height = 0.75 * args.image_dim
78+
img_raw = np.zeros((args.image_dim, int(height), 3), np.uint8)
79+
img = np.float32(img_raw)
80+
81+
img = img.transpose(2, 0, 1)
82+
img = torch.from_numpy(img).unsqueeze(0)
83+
img = img.to(torch.device('cpu'))
84+
torch.onnx.export(net, img, args.output)
85+
print('Finished exporing model to ' + args.output)
8.9 MB
Binary file not shown.
8.9 MB
Binary file not shown.

0 commit comments

Comments
 (0)