# demo.py (forked from tornadomeet/mx-rcnn)
import argparse
import os
import numpy as np
import cv2
import mxnet as mx
from helper.processing.image_processing import resize, transform
from helper.processing.nms import nms
from rcnn.config import config
from rcnn.detector import Detector
from rcnn.symbol import get_vgg_test
from rcnn.tester import vis_all_detection, save_all_detection
from utils.load_model import load_param
def get_net(prefix, epoch, ctx):
    """
    Load a saved checkpoint and wrap it in a test-time Detector.

    :param prefix: checkpoint prefix forwarded to load_param
    :param epoch: checkpoint epoch forwarded to load_param
    :param ctx: mxnet context used both when loading parameters and by the Detector
    :return: Detector built on get_vgg_test(num_classes=...), where the class
             count is the third value returned by load_param
    """
    arg_params, aux_params, n_classes = load_param(prefix, epoch, convert=True, ctx=ctx)
    test_symbol = get_vgg_test(num_classes=n_classes)
    return Detector(test_symbol, ctx, arg_params, aux_params)
# Class labels in the detector's column order: demo_net reads scores[:, i] and
# boxes[:, 4*i:4*i+4] for CLASSES[i]. Slot 0 is the background entry; the other
# 20 names are presumably the PASCAL VOC categories the model was trained on.
CLASSES = ('__background__',) + (
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
)
def demo_net(detector, image_name, vis=False, conf_thresh=0.8, nms_thresh=0.3):
    """
    Run the detector on a single image and hand per-class detections to a drawing helper.

    For every non-background class (CLASSES[1:]), proposals with a score >= conf_thresh
    are kept, suppressed with nms(dets, nms_thresh), and collected as an (N, 5) float32
    array of [x-coords/y-coords (4 columns from boxes), score]. Slot 0 (background)
    is passed on as an empty list.

    Side effect: sets the global ``config.TEST.HAS_RPN = True``.

    :param detector: Detector (as returned by get_net); its im_detect(im_array, im_info)
                     must return (scores, boxes)
    :param image_name: path of the image file, read with cv2.imread
    :param vis: if True pass results to vis_all_detection, otherwise to save_all_detection
    :param conf_thresh: minimum class score for a proposal to be kept (default 0.8)
    :param nms_thresh: threshold forwarded to nms (default 0.3)
    :return: None
    :raises IOError: if image_name does not exist or cannot be decoded as an image
    """
    config.TEST.HAS_RPN = True
    # Explicit raise instead of assert: asserts are stripped when Python runs with -O.
    if not os.path.exists(image_name):
        raise IOError(image_name + ' not found')
    im = cv2.imread(image_name)
    # cv2.imread signals failure by returning None rather than raising.
    if im is None:
        raise IOError(image_name + ' could not be read as an image')

    im_array, im_scale = resize(im, config.SCALES[0], config.MAX_SIZE)
    im_array = transform(im_array, config.PIXEL_MEANS)
    # shape[2] / shape[3] are used as height / width — presumably transform yields NCHW; confirm.
    im_info = np.array([[im_array.shape[2], im_array.shape[3], im_scale]], dtype=np.float32)
    scores, boxes = detector.im_detect(im_array, im_info)

    # Index 0 (background) stays an empty list, as the drawing helpers expect.
    boxes_this_image = [[] for _ in CLASSES]
    for cls_ind in range(1, len(CLASSES)):
        # boxes stores 4 box columns per class side by side; scores has one column per class.
        cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        keep = np.where(cls_scores >= conf_thresh)[0]
        cls_boxes = cls_boxes[keep, :]
        cls_scores = cls_scores[keep]
        dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])).astype(np.float32)
        # Nothing survived the score threshold: skip NMS on an empty (0, 5) array.
        if dets.shape[0] > 0:
            dets = dets[nms(dets, nms_thresh), :]
        boxes_this_image[cls_ind] = dets

    if vis:
        vis_all_detection(im_array, boxes_this_image, CLASSES, 0)
    else:
        save_all_detection(im_array, boxes_this_image, CLASSES, 0)
def parse_args():
    """
    Parse the demo's command-line options from sys.argv.

    --image, --prefix and --epoch have no usable default: argparse would otherwise
    leave them as None, which is only discovered later when the model or image is
    loaded. They are therefore required, so a missing option fails immediately with
    a usage message.

    :return: argparse.Namespace with attributes image (str), prefix (str),
             epoch (int) and gpu_id (int, default 0)
    """
    parser = argparse.ArgumentParser(description='Demonstrate a Faster R-CNN network')
    parser.add_argument('--image', dest='image', help='custom image', type=str,
                        required=True)
    parser.add_argument('--prefix', dest='prefix', help='saved model prefix', type=str,
                        required=True)
    parser.add_argument('--epoch', dest='epoch', help='epoch of pretrained model', type=int,
                        required=True)
    parser.add_argument('--gpu', dest='gpu_id', help='GPU device to test with',
                        default=0, type=int)
    return parser.parse_args()
if __name__ == '__main__':
    # Script entry: read CLI options, load the checkpoint onto the requested GPU,
    # then run the demo on the single given image (vis left at its default).
    options = parse_args()
    device = mx.gpu(options.gpu_id)
    net = get_net(options.prefix, options.epoch, device)
    demo_net(net, options.image)