Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

add prepare_model_param and preset_params for all models #869

Open
wants to merge 74 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
74 commits
Select commit Hold shift + click to select a range
4b69c8a
set fcis params in _models
knorth55 Feb 19, 2019
6341172
update train code for fcis params
knorth55 Feb 19, 2019
52698eb
update eval_instance_segmentation
knorth55 Feb 19, 2019
074525f
split prepare_param and prepare_pretrained_model
knorth55 Mar 23, 2019
1eb6d5f
use preset_param for fcis default param
knorth55 Mar 23, 2019
d650828
use preset_param in examples/fcis
knorth55 Mar 23, 2019
dbbd90c
use preset_param in eval_instance_segmentation
knorth55 Mar 23, 2019
5e38b79
refactor pretrained_model arg
knorth55 Mar 24, 2019
19da306
remove iter2 from fcis parameter
knorth55 Mar 24, 2019
680d2b6
add initialW as preset_param in fcis_resnet101
knorth55 Mar 24, 2019
5d5dfbb
remove iter2 from train scripts
knorth55 Mar 24, 2019
bf7fd78
use preset_param in resnet
knorth55 Mar 24, 2019
c0bf377
fix typo in prepare_pretrained_model
knorth55 Mar 24, 2019
02306ee
fix default iter2 value
knorth55 Mar 24, 2019
92ef569
fix typo in prepare_param
knorth55 Mar 26, 2019
f738136
refactor prepare_param
knorth55 Mar 26, 2019
68809ee
merge into one function
Hakuyume Mar 26, 2019
847daa1
update FCIS
Hakuyume Mar 26, 2019
cc0df20
update examples
Hakuyume Mar 26, 2019
0a6be8d
update ResNet
Hakuyume Mar 26, 2019
b6e4efa
fix
Hakuyume Mar 26, 2019
1ef4c72
support python2
knorth55 Mar 26, 2019
365f448
update test_fcis to pass test
knorth55 Apr 4, 2019
a2c8338
use preset_params in vgg
knorth55 May 3, 2019
4dc1984
use preset_params in faster_rcnn_vgg
knorth55 May 3, 2019
2d1689c
support preset_params in test_faster_rcnn_vgg.py
knorth55 May 3, 2019
07d5439
use preset_params for faster_rcnn examples
knorth55 May 4, 2019
59147bb
use preset_params for ssd_vgg16
knorth55 May 4, 2019
0d3cbf1
use preset_params in ssd examples
knorth55 May 4, 2019
c5d9104
use preset_params in segnet
knorth55 May 4, 2019
7820e3b
use preset_params in segnet examples
knorth55 May 4, 2019
4ae7cc5
use preset_params in fpn for default param
knorth55 May 15, 2019
8799fef
use preset_params in examples/fpn
knorth55 May 15, 2019
01468da
use preset_params in test_faster_rcnn_fpn_resnet
knorth55 May 15, 2019
9c09241
use preset_params in yolo
knorth55 May 15, 2019
4196597
use preset_params in examples/yolo
knorth55 May 15, 2019
b80fa1a
use preset_params in detection/eval_detection
knorth55 May 15, 2019
0906cdb
fix typo in test_faster_rcnn_fpn_resnet
knorth55 May 15, 2019
c447dcb
use preset_params in pspnet
knorth55 May 15, 2019
7aee72f
use preset_params in examples/pspnet
knorth55 May 15, 2019
053c35b
fix test_yolo
knorth55 May 15, 2019
d9f49b6
use preset_params in test_yolo_v2_tiny
knorth55 May 15, 2019
d9775a3
fix test_pspnet
knorth55 May 15, 2019
2f41e6d
use preset_params in senet
knorth55 May 15, 2019
e4992e3
use preset_params in test_se_resnet
knorth55 May 15, 2019
44eae14
use preset_params in resnet_tests
knorth55 May 15, 2019
cc37c79
use preset_params in fcis_tests
knorth55 May 15, 2019
d4a8fc8
use preset_params in deeplab
knorth55 May 15, 2019
951c191
use preset_params in deeplab_tests
knorth55 May 15, 2019
918ce86
fix doc
knorth55 May 15, 2019
ccb3814
fix test_fcis_resnet101
knorth55 May 15, 2019
21eab5b
refactor eval_imagenet to use preset_params
knorth55 May 15, 2019
1f1293c
use preset_params in train_imagenet_multi
knorth55 May 15, 2019
be0bf03
use preset_params in deeplab demo.py
knorth55 May 15, 2019
c6c153c
add args.dataset in demo.py
knorth55 May 15, 2019
437de77
refactor ssd train example script
knorth55 May 15, 2019
9051b16
use preset_params in eval_detection
knorth55 May 16, 2019
9cc2079
use preset_params in eval_instance_segmentation
knorth55 May 16, 2019
20504f7
use preset_params in eval_semantic_segmentation
knorth55 May 15, 2019
1661583
use preset_params in test_vgg
knorth55 May 16, 2019
5cf1922
remove n_fg_class None check from Pretrained model tests
knorth55 May 16, 2019
6911a2e
fix typo in eval_semantic_segmentation_multi.py
knorth55 May 16, 2019
4051c34
fix typo in faster_rcnn_fpn_resnet
knorth55 May 16, 2019
45af569
use VGG preset_params correctly
knorth55 May 16, 2019
86f4641
use preset_params in vgg_tests
knorth55 May 16, 2019
8ea3b8a
copy param in tests
knorth55 May 16, 2019
c466451
fix typo in tests
knorth55 May 16, 2019
4cec8da
use params copy in examples
knorth55 May 16, 2019
979b479
use preset_params in model conversion scripts
knorth55 May 16, 2019
efebb57
use deepcopy in examples
knorth55 May 17, 2019
1d9295c
use deepcopy in faster_rcnn_fpn_resnet
knorth55 May 17, 2019
b914fe8
use deepcopy in tests
knorth55 May 18, 2019
c71d7c4
remove unused lines
knorth55 May 18, 2019
a79c5d2
Merge branch 'master' into prepare-model-param
knorth55 May 29, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 87 additions & 69 deletions chainercv/experimental/links/model/fcis/fcis_resnet101.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,6 @@ class FCISResNet101(FCIS):
localization estimates.
loc_normalize_std (tuple of four floats): Standard deviation
of localization estimates.
iter2 (bool): if the value is set :obj:`True`, Position Sensitive
ROI pooling is executed twice. In the second time, Position
Sensitive ROI pooling uses improved ROIs by the localization
parameters calculated in the first time.
resnet_initialW (callable): Initializer for the layers corresponding to
the ResNet101 layers.
rpn_initialW (callable): Initializer for Region Proposal Network
Expand All @@ -79,90 +75,112 @@ class FCISResNet101(FCIS):
:class:`~chainercv.links.model.faster_rcnn.ProposalCreator`.

"""
_common_param = {
'feat_stride': 16,
'min_size': 600,
'max_size': 1000,
'roi_size': 21,
'group_size': 7,
'ratios': [0.5, 1, 2],
'loc_normalize_mean': (0.0, 0.0, 0.0, 0.0),
'loc_normalize_std': (0.2, 0.2, 0.5, 0.5),
'rpn_initialW': chainer.initializers.Normal(0.01),
'resnet_initialW': chainer.initializers.constant.Zero(),
'head_initialW': chainer.initializers.Normal(0.01),
}
preset_params = {
'sbd': dict({
'n_fg_class': 20,
'anchor_scales': (8, 16, 32),
'proposal_creator_params': {
'nms_thresh': 0.7,
'n_train_pre_nms': 6000,
'n_train_post_nms': 300,
'n_test_pre_nms': 6000,
'n_test_post_nms': 300,
'force_cpu_nms': False,
'min_size': 16,
}},
**_common_param),
'coco': dict({
'n_fg_class': 80,
'anchor_scales': (4, 8, 16, 32),
'proposal_creator_params': {
'nms_thresh': 0.7,
'n_train_pre_nms': 6000,
'n_train_post_nms': 300,
'n_test_pre_nms': 6000,
'n_test_post_nms': 300,
'force_cpu_nms': False,
'min_size': 2,
}},
**_common_param),
}

_models = {
'sbd': {
'param': {'n_fg_class': 20},
'param': preset_params['sbd'],
'url': 'https://chainercv-models.preferred.jp/'
'fcis_resnet101_sbd_trained_2018_06_22.npz',
'cv2': True
'fcis_resnet101_sbd_trained_2018_06_22.npz',
'cv2': True,
},
'sbd_converted': {
'param': {'n_fg_class': 20},
'param': preset_params['sbd'],
'url': 'https://chainercv-models.preferred.jp/'
'fcis_resnet101_sbd_converted_2018_07_02.npz',
'cv2': True
'fcis_resnet101_sbd_converted_2018_07_02.npz',
'cv2': True,
},
'coco': {
'param': {'n_fg_class': 80},
'param': preset_params['coco'],
'url': 'https://chainercv-models.preferred.jp/'
'fcis_resnet101_coco_trained_2019_01_30.npz',
'cv2': True
'fcis_resnet101_coco_trained_2019_01_30.npz',
'cv2': True,
},
'coco_converted': {
'param': {'n_fg_class': 80},
'param': preset_params['coco'],
'url': 'https://chainercv-models.preferred.jp/'
'fcis_resnet101_coco_converted_2019_01_30.npz',
'cv2': True
}
}
feat_stride = 16
proposal_creator_params = {
'nms_thresh': 0.7,
'n_train_pre_nms': 6000,
'n_train_post_nms': 300,
'n_test_pre_nms': 6000,
'n_test_post_nms': 300,
'force_cpu_nms': False,
'min_size': 16
'fcis_resnet101_coco_converted_2019_01_30.npz',
'cv2': True,
},
}

def __init__(
self,
n_fg_class=None,
pretrained_model=None,
min_size=600, max_size=1000,
roi_size=21, group_size=7,
ratios=[0.5, 1, 2], anchor_scales=[8, 16, 32],
loc_normalize_mean=(0.0, 0.0, 0.0, 0.0),
loc_normalize_std=(0.2, 0.2, 0.5, 0.5),
iter2=True,
resnet_initialW=None, rpn_initialW=None, head_initialW=None,
proposal_creator_params=None):
param, path = utils.prepare_pretrained_model(
{'n_fg_class': n_fg_class}, pretrained_model, self._models)

if rpn_initialW is None:
rpn_initialW = chainer.initializers.Normal(0.01)
if resnet_initialW is None and pretrained_model:
resnet_initialW = chainer.initializers.constant.Zero()
if proposal_creator_params is not None:
self.proposal_creator_params = proposal_creator_params

extractor = ResNet101Extractor(
initialW=resnet_initialW)
feat_stride=None,
min_size=None, max_size=None,
roi_size=None, group_size=None,
ratios=None, anchor_scales=None,
loc_normalize_mean=None, loc_normalize_std=None,
rpn_initialW=None, resnet_initialW=None, head_initialW=None,
proposal_creator_params=None,
):
param, path = utils.prepare_model_param(locals(), self._models)

extractor = ResNet101Extractor(initialW=param['resnet_initialW'])
rpn = RegionProposalNetwork(
1024, 512,
ratios=ratios,
anchor_scales=anchor_scales,
feat_stride=self.feat_stride,
initialW=rpn_initialW,
proposal_creator_params=self.proposal_creator_params)
ratios=param['ratios'],
anchor_scales=param['anchor_scales'],
feat_stride=param['feat_stride'],
initialW=param['rpn_initialW'],
proposal_creator_params=param['proposal_creator_params'])
head = FCISResNet101Head(
param['n_fg_class'] + 1,
roi_size=roi_size, group_size=group_size,
spatial_scale=1. / self.feat_stride,
loc_normalize_mean=loc_normalize_mean,
loc_normalize_std=loc_normalize_std,
iter2=iter2, initialW=head_initialW)
roi_size=param['roi_size'], group_size=param['group_size'],
spatial_scale=1. / param['feat_stride'],
loc_normalize_mean=param['loc_normalize_mean'],
loc_normalize_std=param['loc_normalize_std'],
initialW=param['head_initialW'])

mean = np.array([123.15, 115.90, 103.06],
dtype=np.float32)[:, None, None]
mean = np.array(
[123.15, 115.90, 103.06], dtype=np.float32)[:, None, None]

super(FCISResNet101, self).__init__(
extractor, rpn, head,
mean, min_size, max_size,
loc_normalize_mean, loc_normalize_std)
mean, param['min_size'], param['max_size'],
param['loc_normalize_mean'], param['loc_normalize_std'])

if path == 'imagenet':
self._copy_imagenet_pretrained_resnet()
Expand Down Expand Up @@ -269,10 +287,6 @@ class FCISResNet101Head(chainer.Chain):
localization estimates.
loc_normalize_std (tuple of four floats): Standard deviation
of localization estimates.
iter2 (bool): if the value is set :obj:`True`, Position Sensitive
ROI pooling is executed twice. In the second time, Position
Sensitive ROI pooling uses improved ROIs by the localization
parameters calculated in the first time.
initialW (callable): Initializer for the layers.

"""
Expand All @@ -282,7 +296,7 @@ def __init__(
n_class,
roi_size, group_size, spatial_scale,
loc_normalize_mean, loc_normalize_std,
iter2, initialW=None
initialW=None
):
super(FCISResNet101Head, self).__init__()

Expand All @@ -295,7 +309,6 @@ def __init__(
self.roi_size = roi_size
self.loc_normalize_mean = loc_normalize_mean
self.loc_normalize_std = loc_normalize_std
self.iter2 = iter2

with self.init_scope():
self.conv1 = L.Convolution2D(
Expand All @@ -307,7 +320,8 @@ def __init__(
1024, group_size * group_size * 2 * 4,
1, 1, 0, initialW=initialW)

def forward(self, x, rois, roi_indices, img_size, gt_roi_labels=None):
def forward(self, x, rois, roi_indices, img_size,
gt_roi_labels=None, iter2=True):
"""Forward the chain.

We assume that there are :math:`N` batches.
Expand All @@ -323,6 +337,10 @@ def forward(self, x, rois, roi_indices, img_size, gt_roi_labels=None):
roi_indices (array): An array containing indices of images to
which bounding boxes correspond to. Its shape is :math:`(R',)`.
img_size (tuple of int): A tuple containing image size.
iter2 (bool): If this value is :obj:`True`, Position Sensitive
ROI pooling is executed twice. The second pass of Position
Sensitive ROI pooling uses the ROIs refined by the localization
parameters calculated in the first pass.

"""
h = F.relu(self.conv1(x))
Expand All @@ -332,7 +350,7 @@ def forward(self, x, rois, roi_indices, img_size, gt_roi_labels=None):
# PSROI pooling and regression
roi_ag_seg_scores, roi_ag_locs, roi_cls_scores = self._pool(
h_cls_seg, h_ag_loc, rois, roi_indices, gt_roi_labels)
if self.iter2:
if iter2:
# 2nd Iteration
# get rois2 for more precise prediction
roi_ag_locs = roi_ag_locs.array
Expand Down
3 changes: 2 additions & 1 deletion chainercv/experimental/links/model/fcis/fcis_train_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ def forward(self, imgs, masks, labels, bboxes, scale):
sample_roi_index = self.xp.zeros(
(len(sample_roi),), dtype=np.int32)
roi_ag_seg_score, roi_ag_loc, roi_cls_score, _, _ = self.fcis.head(
roi_features, sample_roi, sample_roi_index, img_size, gt_roi_label)
roi_features, sample_roi, sample_roi_index, img_size,
gt_roi_label, iter2=False)

# RPN losses
gt_rpn_loc, gt_rpn_label = self.anchor_target_creator(
Expand Down
62 changes: 44 additions & 18 deletions chainercv/experimental/links/model/pspnet/pspnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ def forward(self, x):

class DilatedResNet(PickableSequentialChain):

preset_params = {
'imagenet': {
'initialW': initializers.constant.Zero(),
}
}
_blocks = {
50: [3, 4, 6, 3],
101: [3, 4, 23, 3],
Expand All @@ -60,13 +65,15 @@ class DilatedResNet(PickableSequentialChain):
_models = {
50: {
'imagenet': {
'param': preset_params['imagenet'],
'url': 'https://chainercv-models.preferred.jp/'
'pspnet_resnet50_imagenet_trained_2018_11_26.npz',
'cv2': True
},
},
101: {
'imagenet': {
'param': preset_params['imagenet'],
'url': 'https://chainercv-models.preferred.jp/'
'pspnet_resnet101_imagenet_trained_2018_11_26.npz',
'cv2': True
Expand All @@ -78,10 +85,9 @@ def __init__(self, n_layer, pretrained_model=None,
initialW=None):
n_block = self._blocks[n_layer]

_, path = utils.prepare_pretrained_model(
{},
pretrained_model,
self._models[n_layer])
param, path = utils.prepare_model_param(
locals(), self._models[n_layer])
initialW = param['initialW']

super(DilatedResNet, self).__init__()
with self.init_scope():
Expand Down Expand Up @@ -156,23 +162,19 @@ def __init__(self, n_class=None, pretrained_model=None,
else:
extractor_pretrained_model = None

param, path = utils.prepare_pretrained_model(
{'n_class': n_class, 'input_size': input_size},
pretrained_model, self._models,
default={'input_size': (713, 713)})
param, path = utils.prepare_model_param(locals(), self._models)
n_class = param['n_class']
input_size = param['input_size']
initialW = param['initialW']
if not isinstance(input_size, (list, tuple)):
input_size = (int(input_size), int(input_size))
self.input_size = input_size

if initialW is None:
if pretrained_model:
initialW = initializers.constant.Zero()

kwargs = self._extractor_kwargs
kwargs.update({'pretrained_model': extractor_pretrained_model,
'initialW': initialW})
kwargs.update({'pretrained_model': extractor_pretrained_model})
if extractor_pretrained_model in self._extractor_cls.preset_params:
kwargs.update(
self._extractor_cls.preset_params[extractor_pretrained_model])
extractor = self._extractor_cls(**kwargs)
extractor.pick = self._extractor_pick

Expand Down Expand Up @@ -302,17 +304,29 @@ class PSPNetResNet101(PSPNet):

"""

preset_params = {
'cityscapes': {
'n_class': 19,
'input_size': (713, 713),
'initialW': initializers.constant.Zero(),
},
'ade20k': {
'n_class': 150,
'input_size': (473, 473),
'initialW': initializers.constant.Zero(),
},
}
_extractor_cls = DilatedResNet
_extractor_kwargs = {'n_layer': 101}
_extractor_pick = ('res4', 'res5')
_models = {
'cityscapes': {
'param': {'n_class': 19, 'input_size': (713, 713)},
'param': preset_params['cityscapes'],
'url': 'https://chainercv-models.preferred.jp/'
'pspnet_resnet101_cityscapes_trained_2018_12_19.npz',
},
'ade20k': {
'param': {'n_class': 150, 'input_size': (473, 473)},
'param': preset_params['ade20k'],
'url': 'https://chainercv-models.preferred.jp/'
'pspnet_resnet101_ade20k_trained_2018_12_23.npz',
},
Expand All @@ -328,17 +342,29 @@ class PSPNetResNet50(PSPNet):

"""

preset_params = {
'cityscapes': {
'n_class': 19,
'input_size': (713, 713),
'initialW': initializers.constant.Zero(),
},
'ade20k': {
'n_class': 150,
'input_size': (473, 473),
'initialW': initializers.constant.Zero(),
},
}
_extractor_cls = DilatedResNet
_extractor_kwargs = {'n_layer': 50}
_extractor_pick = ('res4', 'res5')
_models = {
'cityscapes': {
'param': {'n_class': 19, 'input_size': (713, 713)},
'param': preset_params['cityscapes'],
'url': 'https://chainercv-models.preferred.jp/'
'pspnet_resnet50_cityscapes_trained_2018_12_19.npz',
},
'ade20k': {
'param': {'n_class': 150, 'input_size': (473, 473)},
'param': preset_params['ade20k'],
'url': 'https://chainercv-models.preferred.jp/'
'pspnet_resnet50_ade20k_trained_2018_12_23.npz',
},
Expand Down
Loading