-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy path: mask-rcnn-prediction.py
50 lines (39 loc) · 2.53 KB
/
mask-rcnn-prediction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import mrcnn
import mrcnn.config
import mrcnn.model
import mrcnn.visualize
import cv2
import os
# Human-readable label names for the COCO-pretrained model, hard-coded here
# instead of being read from a labels file. Index 0 is the background class
# "BG". The list is passed to display_instances together with the model's
# class_ids, so presumably entry i names class_id i — keep this exact order
# (TODO confirm against the class order used to train mask_rcnn_coco.h5).
CLASS_NAMES = [
    "BG",
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",
    "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack",
    "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat",
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange",
    "broccoli", "carrot", "hot dog", "pizza", "donut",
    "cake", "chair", "couch", "potted plant", "bed",
    "dining table", "toilet", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven",
    "toaster", "sink", "refrigerator", "book", "clock",
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
]
class SimpleConfig(mrcnn.config.Config):
    """Inference configuration for the COCO-pretrained Mask R-CNN model.

    Only the attributes below are overridden; every other setting is
    inherited unchanged from ``mrcnn.config.Config``.
    """

    # A recognizable name for this configuration.
    NAME = "coco_inference"

    # Use one GPU and one image per GPU. The script below passes a single
    # image to ``model.detect``; presumably these two values must multiply to
    # that batch size (verify against ``mrcnn.config.Config``).
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

    # Total class count, including the background class "BG" that occupies
    # index 0 of CLASS_NAMES (i.e. the real object classes plus one).
    NUM_CLASSES = len(CLASS_NAMES)
# Input files, resolved relative to the current working directory.
WEIGHTS_PATH = "mask_rcnn_coco.h5"
IMAGE_PATH = "sample_image.jpg"

# Initialize the Mask R-CNN model for inference and then load the weights.
# This step builds the Keras model architecture; model_dir is set to the
# current working directory.
model = mrcnn.model.MaskRCNN(mode="inference",
                             config=SimpleConfig(),
                             model_dir=os.getcwd())

# Load the pretrained weights into the model.
model.load_weights(filepath=WEIGHTS_PATH,
                   by_name=True)

# Load the input image. cv2.imread does not raise when the file is missing or
# cannot be decoded; it returns None. Without this check, that failure would
# only surface later as an obscure cv2.error inside cvtColor.
image = cv2.imread(IMAGE_PATH)
if image is None:
    raise FileNotFoundError(f"Could not read image file: {IMAGE_PATH!r}")

# OpenCV loads pixels in BGR channel order; convert to RGB for the model.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Perform a forward pass of the network. detect() is given a list holding one
# image and returns a list of per-image results; keep the first (only) entry.
r = model.detect([image], verbose=0)[0]

# Visualize the detected objects: boxes, masks, labels and scores.
mrcnn.visualize.display_instances(image=image,
                                  boxes=r['rois'],
                                  masks=r['masks'],
                                  class_ids=r['class_ids'],
                                  class_names=CLASS_NAMES,
                                  scores=r['scores'])