-
Notifications
You must be signed in to change notification settings - Fork 50
/
Copy pathconfig.py
108 lines (90 loc) · 4.37 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""Configurate arguments."""
import argparse
# Thresholds are collected by `collect_thresholds.py`.
# Size of the input image fed to the detector (see FEATURE_MAP_SIZE below).
INPUT_IMAGE_SIZE = 512
# Meaning of each feature map channel:
# 0: confidence, 1: point_shape, 2: offset_x, 3: offset_y, 4: cos(direction),
# 5: sin(direction)
NUM_FEATURE_MAP_CHANNEL = 6
# image_size / 2^5 = 512 / 32 = 16
FEATURE_MAP_SIZE = 16
# Threshold used to filter marking points too close to the image boundary.
BOUNDARY_THRESH = 0.05
# Thresholds to determine whether a detected point matches the ground truth.
SQUARED_DISTANCE_THRESH = 0.000277778 # (10 / 600) ** 2: 10 pixels in a 600*600 image
DIRECTION_ANGLE_THRESH = 0.5235987755982988 # 30 degrees in radians (pi / 6)
# Min/max distance bounds for the two slot kinds.
# NOTE(review): "V"/"H" presumably mean vertical/horizontal slots and the
# values look like normalized image coordinates -- TODO confirm.
VSLOT_MIN_DIST = 0.044771278151623496
VSLOT_MAX_DIST = 0.1099427457599304
HSLOT_MIN_DIST = 0.15057789144568634
HSLOT_MAX_DIST = 0.44449496544202816
# NOTE(review): separator lengths, presumably in the same normalized units as
# the distances above -- TODO confirm.
SHORT_SEPARATOR_LENGTH = 0.199519231
LONG_SEPARATOR_LENGTH = 0.46875
# angle_prediction_error = 0.1384059287593468 collected from evaluate.py
# Both angle tolerances below add that prediction error as their second term.
BRIDGE_ANGLE_DIFF = 0.09757113548987695 + 0.1384059287593468
SEPARATOR_ANGLE_DIFF = 0.284967562063968 + 0.1384059287593468
# NOTE(review): presumably a dot-product threshold used when suppressing
# overlapping slots -- confirm against the inference code.
SLOT_SUPPRESSION_DOT_PRODUCT_THRESH = 0.8
# precision = 0.995585, recall = 0.995805
CONFID_THRESH_FOR_POINT = 0.11676871
def add_common_arguments(parser):
    """Register the options shared by the training and inference parsers.

    Adds ``--detector_weights``, ``--depth_factor``, ``--disable_cuda`` and
    ``--gpu_id`` to ``parser`` in place, in that order. Returns nothing.
    """
    # Declared as a table so the registration order (and therefore the order
    # shown by ``--help``) is visible at a glance.
    shared_options = (
        ('--detector_weights',
         dict(help="The weights of pretrained detector.")),
        ('--depth_factor',
         dict(type=int, default=32, help="Depth factor.")),
        ('--disable_cuda',
         dict(action='store_true', help="Disable CUDA.")),
        ('--gpu_id',
         dict(type=int, default=0, help="Select which gpu to use.")),
    )
    for flag, settings in shared_options:
        parser.add_argument(flag, **settings)
def get_parser_for_training():
    """Build the command-line parser used for training.

    The parser carries the training-specific options declared below followed
    by the shared options registered by ``add_common_arguments``.
    """
    training_options = (
        ('--dataset_directory',
         dict(required=True, help="The location of dataset.")),
        ('--optimizer_weights',
         dict(help="The weights of optimizer.")),
        ('--batch_size',
         dict(type=int, default=24, help="Batch size.")),
        ('--data_loading_workers',
         dict(type=int, default=32,
              help="Number of workers for data loading.")),
        ('--num_epochs',
         dict(type=int, default=20, help="Number of epochs to train for.")),
        ('--lr',
         dict(type=float, default=1e-4,
              help="The learning rate of back propagation.")),
        ('--enable_visdom',
         dict(action='store_true',
              help="Enable Visdom to visualize training progress")),
    )
    arg_parser = argparse.ArgumentParser()
    for flag, settings in training_options:
        arg_parser.add_argument(flag, **settings)
    # Shared options go last so they appear after the training ones in --help.
    add_common_arguments(arg_parser)
    return arg_parser
def get_parser_for_evaluation():
    """Build the command-line parser used for evaluation.

    Accepts a required ``--dataset_directory``, the ``--enable_visdom`` flag,
    and the shared options registered by ``add_common_arguments``.
    """
    evaluation_options = (
        ('--dataset_directory',
         dict(required=True, help="The location of dataset.")),
        ('--enable_visdom',
         dict(action='store_true',
              help="Enable Visdom to visualize training progress")),
    )
    arg_parser = argparse.ArgumentParser()
    for flag, settings in evaluation_options:
        arg_parser.add_argument(flag, **settings)
    add_common_arguments(arg_parser)
    return arg_parser
def get_parser_for_ps_evaluation():
    """Return the argument parser for ps evaluation.

    Unlike ``get_parser_for_evaluation``, labels and images are supplied as
    two separate required directories (``--label_directory`` and
    ``--image_directory``). The shared options from ``add_common_arguments``
    are appended afterwards.

    NOTE(review): "ps" presumably stands for parking slot -- confirm.
    """
    parser = argparse.ArgumentParser()
    # Each directory option gets its own help text; previously both were
    # described as "The location of dataset.", which made --help ambiguous.
    parser.add_argument('--label_directory', required=True,
                        help="The location of labels.")
    parser.add_argument('--image_directory', required=True,
                        help="The location of images.")
    parser.add_argument('--enable_visdom', action='store_true',
                        help="Enable Visdom to visualize training progress")
    add_common_arguments(parser)
    return parser
def get_parser_for_inference():
    """Build the command-line parser used for inference.

    ``--mode`` is required and must be ``image`` or ``video``. The remaining
    inference options are optional and are followed by the shared options
    registered by ``add_common_arguments``.
    """
    inference_options = (
        ('--mode',
         dict(required=True, choices=['image', 'video'],
              help="Inference image or video.")),
        ('--video',
         dict(help="Video path if you choose to inference video.")),
        ('--inference_slot',
         dict(action='store_true', help="Perform slot inference.")),
        ('--thresh',
         dict(type=float, default=0.5, help="Detection threshold.")),
        ('--save',
         dict(action='store_true', help="Save detection result to file.")),
    )
    arg_parser = argparse.ArgumentParser()
    for flag, settings in inference_options:
        arg_parser.add_argument(flag, **settings)
    add_common_arguments(arg_parser)
    return arg_parser