From 1108eafa95e9a6e03d5fab5bcf2b67c540be4bc4 Mon Sep 17 00:00:00 2001 From: enes Date: Wed, 16 Oct 2019 10:00:57 +0300 Subject: [PATCH 1/3] Batch inference added --- darknet.py | 81 +++++++++++++++++++++++++++++++++++++++++++++++ include/darknet.h | 6 ++++ src/network.c | 81 +++++++++++++++++++++++++++++++++++++++++++++++ src/yolo_layer.c | 48 ++++++++++++++++++++++++++++ src/yolo_layer.h | 2 ++ 5 files changed, 218 insertions(+) diff --git a/darknet.py b/darknet.py index 10c9a456ebd..35987465da1 100644 --- a/darknet.py +++ b/darknet.py @@ -61,6 +61,9 @@ class DETECTION(Structure): ("objectness", c_float), ("sort_class", c_int)] +class DETNUMPAIR(Structure): + _fields_ = [("num", c_int), + ("dets", POINTER(DETECTION))] class IMAGE(Structure): _fields_ = [("w", c_int), @@ -157,6 +160,9 @@ def network_height(net): free_detections = lib.free_detections free_detections.argtypes = [POINTER(DETECTION), c_int] +free_batch_detections = lib.free_batch_detections +free_batch_detections.argtypes = [POINTER(DETNUMPAIR), c_int] + free_ptrs = lib.free_ptrs free_ptrs.argtypes = [POINTER(c_void_p), c_int] @@ -206,6 +212,11 @@ def network_height(net): predict_image_letterbox.argtypes = [c_void_p, IMAGE] predict_image_letterbox.restype = POINTER(c_float) +network_predict_custom = lib.network_predict_custom +network_predict_custom.argtypes = [c_void_p, IMAGE, c_int, c_int, c_int, + c_float, c_float, POINTER(c_int), c_int, c_int] +network_predict_custom.restype = POINTER(DETNUMPAIR) + def array_to_image(arr): import numpy as np # need to return old values to avoid python freeing memory @@ -441,5 +452,75 @@ def performDetect(imagePath="data/dog.jpg", thresh= 0.25, configPath = "./cfg/yo print("Unable to show image: "+str(e)) return detections +def performBatchDetect(thresh= 0.25, configPath = "./cfg/yolov3.cfg", weightPath = "yolov3.weights", metaPath= "./cfg/coco.data", hier_thresh=.5, nms=.45, batch_size=3): + import cv2 + import numpy as np + # NB! Image sizes should be the same + # You can change the images, yet, be sure that they have the same width and height + img_samples = ['data/person.jpg', 'data/person.jpg', 'data/person.jpg'] + image_list = [cv2.imread(k) for k in img_samples] + + if len(image_list) > batch_size: + raise ValueError( + "Please check if batch size is equal to the number of images passed to the function") + net = load_net_custom(configPath.encode('utf-8'), weightPath.encode('utf-8'), 0, batch_size) + meta = load_meta(metaPath.encode('utf-8')) + pred_height, pred_width, c = image_list[0].shape + net_width, net_height = (network_width(net), network_height(net)) + img_list = [] + for custom_image_bgr in image_list: + custom_image = cv2.cvtColor(custom_image_bgr, cv2.COLOR_BGR2RGB) + custom_image = cv2.resize( + custom_image, (net_width, net_height), interpolation=cv2.INTER_NEAREST) + custom_image = custom_image.transpose(2, 0, 1) + img_list.append(custom_image) + + arr = np.concatenate(img_list, axis=0) + arr = np.ascontiguousarray(arr.flat, dtype=np.float32) / 255.0 + data = arr.ctypes.data_as(POINTER(c_float)) + im = IMAGE(net_width, net_height, c, data) + + batch_dets = network_predict_custom(net, im, batch_size, pred_width, + pred_height, thresh, hier_thresh, None, 0, 0) + batch_boxes = [] + batch_scores = [] + batch_classes = [] + for b in range(batch_size): + num = batch_dets[b].num + dets = batch_dets[b].dets + if nms: + do_nms_obj(dets, num, meta.classes, nms) + boxes = [] + scores = [] + classes = [] + for i in range(num): + det = dets[i] + score = -1 + label = None + for c in range(det.classes): + p = det.prob[c] + if p > score: + score = p + label = c + if score > thresh: + box = det.bbox + left, top, right, bottom = map(int,(box.x - box.w / 2, box.y - box.h / 2, + box.x + box.w / 2, box.y + box.h / 2)) + boxes.append((top, left, bottom, right)) + scores.append(score) + classes.append(label) + boxColor = (int(255 * (1 - (score ** 2))), int(255 * (score ** 2)), 0) + cv2.rectangle(image_list[b], (left, top), + (right, bottom), boxColor, 2) + cv2.imwrite(os.path.basename(img_samples[b]),image_list[b]) + + batch_boxes.append(boxes) + batch_scores.append(scores) + batch_classes.append(classes) + free_batch_detections(batch_dets, batch_size) + return batch_boxes, batch_scores, batch_classes + if __name__ == "__main__": print(performDetect()) + # Uncomment the following line to see batch inference working + #print(performBatchDetect()) \ No newline at end of file diff --git a/include/darknet.h b/include/darknet.h index e78abe6a5c9..f0648066b7e 100644 --- a/include/darknet.h +++ b/include/darknet.h @@ -730,6 +730,12 @@ typedef struct detection{ int sort_class; } detection; +// network.c -batch inference +typedef struct detNumPair { + int num; + detection *dets; +} detNumPair, *pdetNumPair; + // matrix.h typedef struct matrix { int rows, cols; diff --git a/src/network.c b/src/network.c index 82dc4d53978..c1de0a4636c 100644 --- a/src/network.c +++ b/src/network.c @@ -694,6 +694,22 @@ int num_detections(network *net, float thresh) return s; } +int num_detections_custom(network *net, float thresh, int b) +{ + int i; + int s = 0; + for (i = 0; i < net->n; ++i) { + layer l = net->layers[i]; + if (l.type == YOLO) { + s += yolo_num_detections_custom(l, thresh, b); + } + if (l.type == DETECTION || l.type == REGION) { + s += l.w*l.h*l.n; + } + } + return s; +} + detection *make_network_boxes(network *net, float thresh, int *num) { layer l = net->layers[net->n - 1]; @@ -710,6 +726,21 @@ detection *make_network_boxes(network *net, float thresh, int *num) return dets; } +detection *make_network_boxes_custom(network *net, float thresh, int *num, int batch) +{ + int i; + layer l = net->layers[net->n - 1]; + int nboxes = num_detections_custom(net, thresh, batch); + if (num) *num = nboxes; + detection* dets = (detection*)calloc(nboxes, sizeof(detection)); + for (i = 0; i < nboxes; ++i) { + dets[i].prob = (float*)calloc(l.classes, sizeof(float)); + if (l.coords > 4) { + dets[i].mask = (float*)calloc(l.coords - 4, sizeof(float)); + } + } + return dets; +} void custom_get_region_detections(layer l, int w, int h, int net_w, int net_h, float thresh, int *map, float hier, int relative, detection *dets, int letter) { @@ -761,6 +792,33 @@ void fill_network_boxes(network *net, int w, int h, float thresh, float hier, in } } +void fill_network_boxes_custom(network *net, int w, int h, float thresh, float hier, int *map, int relative, detection *dets, int letter, int batch) +{ + int prev_classes = -1; + int j; + for (j = 0; j < net->n; ++j) { + layer l = net->layers[j]; + if (l.type == YOLO) { + int count = get_yolo_detections_custom(l, w, h, net->w, net->h, thresh, map, relative, dets, letter, batch); + dets += count; + if (prev_classes < 0) prev_classes = l.classes; + else if (prev_classes != l.classes) { + printf(" Error: Different [yolo] layers have different number of classes = %d and %d - check your cfg-file! \n", + prev_classes, l.classes); + } + } + if (l.type == REGION) { + custom_get_region_detections(l, w, h, net->w, net->h, thresh, map, hier, relative, dets, letter); + //get_region_detections(l, w, h, net->w, net->h, thresh, map, hier, relative, dets); + dets += l.w*l.h*l.n; + } + if (l.type == DETECTION) { + get_detection_detections(l, w, h, thresh, dets); + dets += l.w*l.h*l.n; + } + } +} + detection *get_network_boxes(network *net, int w, int h, float thresh, float hier, int *map, int relative, int *num, int letter) { detection *dets = make_network_boxes(net, thresh, num); @@ -778,6 +836,14 @@ void free_detections(detection *dets, int n) free(dets); } +void free_batch_detections(detNumPair *detNumPairs, int n) +{ + int i; + for(i=0; iw, net->h); diff --git a/src/yolo_layer.c b/src/yolo_layer.c index 20ee8e34391..6b589d71e38 100644 --- a/src/yolo_layer.c +++ b/src/yolo_layer.c @@ -461,6 +461,21 @@ int yolo_num_detections(layer l, float thresh) return count; } +int yolo_num_detections_custom(layer l, float thresh, int batch) +{ + int i, n; + int count = 0; + for (i = 0; i < l.w*l.h; ++i){ + for(n = 0; n < l.n; ++n){ + int obj_index = entry_index(l, batch, n*l.w*l.h + i, 4); + if(l.output[obj_index] > thresh){ + ++count; + } + } + } + return count; +} + void avg_flipped_yolo(layer l) { int i,j,n,z; @@ -522,6 +537,39 @@ int get_yolo_detections(layer l, int w, int h, int netw, int neth, float thresh, return count; } +int get_yolo_detections_custom(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets, int letter,int batch) +{ + //printf("\n l.batch = %d, l.w = %d, l.h = %d, l.n = %d \n", l.batch, l.w, l.h, l.n); + int i,j,n; + float *predictions = l.output; + //if (l.batch == 2) avg_flipped_yolo(l); + int count = 0; + for (i = 0; i < l.w*l.h; ++i){ + int row = i / l.w; + int col = i % l.w; + for(n = 0; n < l.n; ++n){ + int obj_index = entry_index(l, batch, n*l.w*l.h + i, 4); + float objectness = predictions[obj_index]; + //if(objectness <= thresh) continue; // incorrect behavior for Nan values + if (objectness > thresh) { + //printf("\n objectness = %f, thresh = %f, i = %d, n = %d \n", objectness, thresh, i, n); + int box_index = entry_index(l, batch, n*l.w*l.h + i, 0); + dets[count].bbox = get_yolo_box(predictions, l.biases, l.mask[n], box_index, col, row, l.w, l.h, netw, neth, l.w*l.h); + dets[count].objectness = objectness; + dets[count].classes = l.classes; + for (j = 0; j < l.classes; ++j) { + int class_index = entry_index(l, batch, n*l.w*l.h + i, 4 + 1 + j); + float prob = objectness*predictions[class_index]; + dets[count].prob[j] = (prob > thresh) ? prob : 0; + } + ++count; + } + } + } + correct_yolo_boxes(dets, count, w, h, netw, neth, relative, letter); + return count; +} + #ifdef GPU void forward_yolo_layer_gpu(const layer l, network_state state) diff --git a/src/yolo_layer.h b/src/yolo_layer.h index d67482fe2fb..83607452851 100644 --- a/src/yolo_layer.h +++ b/src/yolo_layer.h @@ -13,7 +13,9 @@ void forward_yolo_layer(const layer l, network_state state); void backward_yolo_layer(const layer l, network_state state); void resize_yolo_layer(layer *l, int w, int h); int yolo_num_detections(layer l, float thresh); +int yolo_num_detections_custom(layer l, float thresh, int batch); int get_yolo_detections(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets, int letter); +int get_yolo_detections_custom(layer l, int w, int h, int netw, int neth, float thresh, int *map, int relative, detection *dets, int letter, int batch); void correct_yolo_boxes(detection *dets, int n, int w, int h, int netw, int neth, int relative, int letter); #ifdef GPU From c435fca73874d69dd6bdeb988fd45a0e8fb9a9c1 Mon Sep 17 00:00:00 2001 From: enes Date: Wed, 16 Oct 2019 15:08:37 +0300 Subject: [PATCH 2/3] explicit casting for c++ --- src/network.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/network.c b/src/network.c index c1de0a4636c..ab7533b13bf 100644 --- a/src/network.c +++ b/src/network.c @@ -919,7 +919,7 @@ detNumPair* network_predict_custom(network *net, image im, int batch, int w, int { set_batch_network(net, batch); network_predict(*net, im.data); - detNumPair *pdets = malloc(batch*sizeof(detNumPair)); + detNumPair *pdets = ( struct detNumPair * )malloc(batch*sizeof(detNumPair)); int num; for(int b=0;b Date: Sat, 19 Oct 2019 16:18:44 +0300 Subject: [PATCH 3/3] batch inference refactoring --- darknet.py | 17 +++++++---------- src/network.c | 32 ++++++++++++++++---------------- src/yolo_layer.c | 5 ++--- src/yolo_layer.h | 4 ++-- 4 files changed, 27 insertions(+), 31 deletions(-) diff --git a/darknet.py b/darknet.py index 35987465da1..1dde00fbab1 100644 --- a/darknet.py +++ b/darknet.py @@ -212,10 +212,10 @@ def network_height(net): predict_image_letterbox.argtypes = [c_void_p, IMAGE] predict_image_letterbox.restype = POINTER(c_float) -network_predict_custom = lib.network_predict_custom -network_predict_custom.argtypes = [c_void_p, IMAGE, c_int, c_int, c_int, +network_predict_batch = lib.network_predict_batch +network_predict_batch.argtypes = [c_void_p, IMAGE, c_int, c_int, c_int, c_float, c_float, POINTER(c_int), c_int, c_int] -network_predict_custom.restype = POINTER(DETNUMPAIR) +network_predict_batch.restype = POINTER(DETNUMPAIR) def array_to_image(arr): import numpy as np @@ -460,9 +460,6 @@ def performBatchDetect(thresh= 0.25, configPath = "./cfg/yolov3.cfg", weightPath img_samples = ['data/person.jpg', 'data/person.jpg', 'data/person.jpg'] image_list = [cv2.imread(k) for k in img_samples] - if len(image_list) > batch_size: - raise ValueError( - "Please check if batch size is equal to the number of images passed to the function") net = load_net_custom(configPath.encode('utf-8'), weightPath.encode('utf-8'), 0, batch_size) meta = load_meta(metaPath.encode('utf-8')) pred_height, pred_width, c = image_list[0].shape @@ -480,7 +477,7 @@ def performBatchDetect(thresh= 0.25, configPath = "./cfg/yolov3.cfg", weightPath data = arr.ctypes.data_as(POINTER(c_float)) im = IMAGE(net_width, net_height, c, data) - batch_dets = network_predict_custom(net, im, batch_size, pred_width, + batch_dets = network_predict_batch(net, im, batch_size, pred_width, pred_height, thresh, hier_thresh, None, 0, 0) batch_boxes = [] batch_scores = [] @@ -521,6 +518,6 @@ def performBatchDetect(thresh= 0.25, configPath = "./cfg/yolov3.cfg", weightPath return batch_boxes, batch_scores, batch_classes if __name__ == "__main__": - print(performDetect()) - # Uncomment the following line to see batch inference working - #print(performBatchDetect()) \ No newline at end of file + #print(performDetect()) + #Uncomment the following line to see batch inference working + print(performBatchDetect()) \ No newline at end of file diff --git a/src/network.c b/src/network.c index ab7533b13bf..c3a195a70e5 100644 --- a/src/network.c +++ b/src/network.c @@ -694,14 +694,14 @@ int num_detections(network *net, float thresh) return s; } -int num_detections_custom(network *net, float thresh, int b) +int num_detections_batch(network *net, float thresh, int batch) { int i; int s = 0; for (i = 0; i < net->n; ++i) { layer l = net->layers[i]; if (l.type == YOLO) { - s += yolo_num_detections_custom(l, thresh, b); + s += yolo_num_detections_batch(l, thresh, batch); } if (l.type == DETECTION || l.type == REGION) { s += l.w*l.h*l.n; @@ -726,12 +726,13 @@ detection *make_network_boxes(network *net, float thresh, int *num) return dets; } -detection *make_network_boxes_custom(network *net, float thresh, int *num, int batch) +detection *make_network_boxes_batch(network *net, float thresh, int *num, int batch) { int i; layer l = net->layers[net->n - 1]; - int nboxes = num_detections_custom(net, thresh, batch); - if (num) *num = nboxes; + int nboxes = num_detections_batch(net, thresh, batch); + assert(num != NULL); + *num = nboxes; detection* dets = (detection*)calloc(nboxes, sizeof(detection)); for (i = 0; i < nboxes; ++i) { dets[i].prob = (float*)calloc(l.classes, sizeof(float)); @@ -792,14 +793,14 @@ void fill_network_boxes(network *net, int w, int h, float thresh, float hier, in } } -void fill_network_boxes_custom(network *net, int w, int h, float thresh, float hier, int *map, int relative, detection *dets, int letter, int batch) +void fill_network_boxes_batch(network *net, int w, int h, float thresh, float hier, int *map, int relative, detection *dets, int letter, int batch) { int prev_classes = -1; int j; for (j = 0; j < net->n; ++j) { layer l = net->layers[j]; if (l.type == YOLO) { - int count = get_yolo_detections_custom(l, w, h, net->w, net->h, thresh, map, relative, dets, letter, batch); + int count = get_yolo_detections_batch(l, w, h, net->w, net->h, thresh, map, relative, dets, letter, batch); dets += count; if (prev_classes < 0) prev_classes = l.classes; else if (prev_classes != l.classes) { @@ -840,7 +841,7 @@ void free_batch_detections(detNumPair *detNumPairs, int n) { int i; for(i=0; i