#!/usr/bin/python3
from __future__ import print_function

import os
import sys
import torch
import torch.backends.cudnn as cudnn
import argparse
import cv2
import numpy as np
from collections import OrderedDict

sys.path.append(os.getcwd() + '/../../src')

from config import cfg
from prior_box import PriorBox
from nms import nms
from utils import decode
from timer import Timer
from yufacedetectnet import YuFaceDetectNet

parser = argparse.ArgumentParser(description='Face and Landmark Detection')

parser.add_argument('-m', '--trained_model', default='weights/yunet_final.pth',
                    type=str, help='Trained state_dict file path to open')
parser.add_argument('-d', '--image_dim', default=320,
                    type=int, help='Input image width')
parser.add_argument('-o', '--output', default='onnx/facedetectcnn.onnx',
                    type=str, help='The output ONNX file, trained parameters inside')
args = parser.parse_args()
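
# Example invocation (the script name below is illustrative; the flags and
# defaults are the ones defined above):
#   python3 exportonnx.py -m weights/yunet_final.pth -d 320 -o onnx/facedetectcnn.onnx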

def check_keys(model, pretrained_state_dict):
    ckpt_keys = set(pretrained_state_dict.keys())
    model_keys = set(model.state_dict().keys())
    used_pretrained_keys = model_keys & ckpt_keys
    unused_pretrained_keys = ckpt_keys - model_keys
    missing_keys = model_keys - ckpt_keys
    print('Missing keys:{}'.format(len(missing_keys)))
    print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
    print('Used keys:{}'.format(len(used_pretrained_keys)))
    assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
    return True
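
# Note: check_keys only reports how the checkpoint keys overlap with the model keys
# and asserts that the overlap is non-empty; remaining mismatches are tolerated
# because load_state_dict is called with strict=False below.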

def remove_prefix(state_dict, prefix):
    ''' Old-style checkpoints store all parameter names with the common prefix 'module.' '''
    print('remove prefix \'{}\''.format(prefix))
    f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
    return {f(key): value for key, value in state_dict.items()}
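
# Illustrative only: for a checkpoint saved from an nn.DataParallel-wrapped model,
# the parameter keys carry a 'module.' prefix, which remove_prefix strips
# (the key name here is made up):
#   remove_prefix({'module.conv1.weight': w}, 'module.')  ->  {'conv1.weight': w}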

def load_model(model, pretrained_path, load_to_cpu):
    print('Loading pretrained model from {}'.format(pretrained_path))
    if load_to_cpu:
        pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
    else:
        device = torch.cuda.current_device()
        pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
    if "state_dict" in pretrained_dict.keys():
        pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
    else:
        pretrained_dict = remove_prefix(pretrained_dict, 'module.')
    check_keys(model, pretrained_dict)
    model.load_state_dict(pretrained_dict, strict=False)
    return model


if __name__ == '__main__':

    torch.set_grad_enabled(False)

    # net and model
    net = YuFaceDetectNet(phase='test', size=None)  # initialize detector
    net = load_model(net, args.trained_model, True)
    net.eval()

    print('Finished loading model!')

    # Dummy input for tracing: the height is 3/4 of the width given by --image_dim.
    # NumPy images are laid out (height, width, channels), so the height goes first.
    height = 0.75 * args.image_dim
    img_raw = np.zeros((int(height), args.image_dim, 3), np.uint8)
    img = np.float32(img_raw)

    img = img.transpose(2, 0, 1)                # HWC -> CHW
    img = torch.from_numpy(img).unsqueeze(0)    # add the batch dimension
    img = img.to(torch.device('cpu'))           # export is done on CPU
    torch.onnx.export(net, img, args.output)
    print('Finished exporting model to ' + args.output)
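
    # Optional sanity check of the exported file (sketch only; onnxruntime is not a
    # dependency of this script and a CPU-only install is assumed):
    #   import onnxruntime as ort
    #   sess = ort.InferenceSession(args.output, providers=['CPUExecutionProvider'])
    #   input_name = sess.get_inputs()[0].name
    #   outputs = sess.run(None, {input_name: img.numpy()})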