-
Notifications
You must be signed in to change notification settings - Fork 50
/
Copy pathps_evaluate.py
81 lines (70 loc) · 2.91 KB
/
ps_evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""Evaluate directional marking point detector."""
import json
import os
import cv2 as cv
import numpy as np
import torch
import config
import util
from data import match_slots, Slot
from model import DirectionalPointDetector
from inference import detect_marking_points, inference_slots
def get_ground_truths(label):
    """Build the list of ground-truth slots stored in one label record.

    Args:
        label: Mapping with a 'slots' entry and a 'marks' entry. Each slot
            row starts with two mark indices; one is subtracted from each
            before indexing into 'marks', so they are treated as 1-based.
            A single slot or a single mark given as a flat sequence is
            accepted and promoted to a one-row table.

    Returns:
        A list with one ``Slot`` per slot row, built from the first two
        values of both referenced marks after shifting by 0.5 and dividing
        by 600. Returns an empty list when the label has no slots; 'marks'
        is not read in that case.
    """
    slot_table = np.array(label['slots'])
    if slot_table.size == 0:
        return []
    if slot_table.ndim < 2:
        slot_table = slot_table[np.newaxis, ...]

    mark_table = np.array(label['marks'])
    if mark_table.ndim < 2:
        mark_table = mark_table[np.newaxis, ...]

    def to_slot(entry):
        """Turn one slot row into a Slot of rescaled endpoint coordinates."""
        first = mark_table[entry[0] - 1]
        second = mark_table[entry[1] - 1]
        corners = np.array([first[0], first[1], second[0], second[1]])
        # NOTE(review): presumably maps pixel positions of a 600-pixel-wide
        # image into the unit range -- confirm against the dataset format.
        corners = (corners - 0.5) / 600
        return Slot(*corners)

    return [to_slot(entry) for entry in slot_table]
def psevaluate_detector(args):
    """Evaluate directional point detector.

    Runs the detector over every label file found in
    ``args.label_directory``, infers slots from the detected marking points,
    compares them with the labelled slots and logs the resulting average
    precision.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``disable_cuda``, ``gpu_id``, ``depth_factor``,
            ``detector_weights``, ``enable_visdom``, ``label_directory`` and
            ``image_directory``. ``args.cuda`` is written as a side effect.

    Side effects:
        Disables gradient tracking globally via ``torch.set_grad_enabled``,
        prints the index and base name of each processed label file, and
        logs (optionally plots) the results through ``util.Logger``.
    """
    args.cuda = not args.disable_cuda and torch.cuda.is_available()
    device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
    torch.set_grad_enabled(False)

    dp_detector = DirectionalPointDetector(
        3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
    if args.detector_weights:
        # map_location remaps the stored tensors onto the evaluation device,
        # so a checkpoint saved from a GPU still loads on a CPU-only run (or
        # on a different GPU) instead of failing during deserialization.
        dp_detector.load_state_dict(
            torch.load(args.detector_weights, map_location=device))
    dp_detector.eval()

    logger = util.Logger(enable_visdom=args.enable_visdom)

    ground_truths_list = []
    predictions_list = []
    for idx, label_file in enumerate(os.listdir(args.label_directory)):
        name = os.path.splitext(label_file)[0]
        print(idx, name)
        # The image is looked up by the label's base name with a .jpg suffix.
        image = cv.imread(os.path.join(args.image_directory, name + '.jpg'))
        pred_points = detect_marking_points(
            dp_detector, image, config.CONFID_THRESH_FOR_POINT, device)

        # Each prediction is indexed as [0] -> score, [1] -> marking point.
        slots = []
        if pred_points:
            marking_points = [pred[1] for pred in pred_points]
            slots = inference_slots(marking_points)

        pred_slots = []
        for slot in slots:
            point_a = marking_points[slot[0]]
            point_b = marking_points[slot[1]]
            # A slot is scored by the weaker of its two endpoints.
            prob = min(pred_points[slot[0]][0], pred_points[slot[1]][0])
            pred_slots.append(
                (prob, Slot(point_a.x, point_a.y, point_b.x, point_b.y)))
        predictions_list.append(pred_slots)

        label_path = os.path.join(args.label_directory, label_file)
        with open(label_path, 'r') as label_stream:
            ground_truths_list.append(
                get_ground_truths(json.load(label_stream)))

    precisions, recalls = util.calc_precision_recall(
        ground_truths_list, predictions_list, match_slots)
    average_precision = util.calc_average_precision(precisions, recalls)
    if args.enable_visdom:
        logger.plot_curve(precisions, recalls)
    logger.log(average_precision=average_precision)
if __name__ == '__main__':
    # Script entry point: parse the evaluation options from the command
    # line, then run the slot-level evaluation with them.
    evaluation_parser = config.get_parser_for_ps_evaluation()
    cli_args = evaluation_parser.parse_args()
    psevaluate_detector(cli_args)