Skip to content

Commit

Permalink
update detection configs
Browse files Browse the repository at this point in the history
  • Loading branch information
czczup committed May 31, 2022
1 parent ed505bf commit b8857ea
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 6 deletions.
21 changes: 21 additions & 0 deletions detection/configs/cascade_rcnn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Cascade R-CNN

> [Cascade R-CNN: High Quality Object Detection and Instance Segmentation](https://arxiv.org/abs/1906.09756)
<!-- [ALGORITHM] -->

## Introduction

In object detection, the intersection over union (IoU) threshold is frequently used to define positives/negatives. The threshold used to train a detector defines its quality. While the commonly used threshold of 0.5 leads to noisy (low-quality) detections, detection performance frequently degrades for larger thresholds. This paradox of high-quality detection has two causes: 1) overfitting, due to vanishing positive samples for large thresholds, and 2) inference-time quality mismatch between detector and test hypotheses. A multi-stage object detection architecture, the Cascade R-CNN, composed of a sequence of detectors trained with increasing IoU thresholds, is proposed to address these problems. The detectors are trained sequentially, using the output of a detector as training set for the next. This resampling progressively improves hypotheses quality, guaranteeing a positive training set of equivalent size for all detectors and minimizing overfitting. The same cascade is applied at inference, to eliminate quality mismatches between hypotheses and detectors. An implementation of the Cascade R-CNN without bells or whistles achieves state-of-the-art performance on the COCO dataset, and significantly improves high-quality detection on generic and specific object detection datasets, including VOC, KITTI, CityPerson, and WiderFace. Finally, the Cascade R-CNN is generalized to instance segmentation, with nontrivial improvements over the Mask R-CNN.

<div align=center>
<img src="https://user-images.githubusercontent.com/40661020/143872197-d99b90e4-4f05-4329-80a4-327ac862a051.png"/>
</div>

## Results and Models

### Cascade Mask R-CNN

| Backbone | Pre-train | Lr schd | box AP | mask AP | #Param | Config | Download |
|:-------------:|:---------------------------------------------------------------------------------:|:-------:|:------:|:-------:|:------:|:---------------------------------------------------------------:|:--------------------------------------------------------------------------------------------------------------------------------:|
| ViT-Adapter-S | [DeiT-S](https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth) | 3x | 51.5 | 44.5 | 86M | [config](./cascade_mask_rcnn_deit_adapter_small_fpn_3x_coco.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.1.3/cascade_mask_rcnn_deit_adapter_small_fpn_3x_coco.pth.tar) |
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Copyright (c) Shanghai AI Lab. All rights reserved.
_base_ = [
'../_base_/models/cascade_mask_rcnn_r50_fpn.py',
'../_base_/datasets/coco_instance.py',
'../_base_/schedules/schedule_3x.py',
'../_base_/default_runtime.py'
]
# pretrained = 'https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'
pretrained = 'pretrained/deit_small_patch16_224-cd65a155_.pth'
model = dict(
backbone=dict(
_delete_=True,
type='ViTAdapter',
patch_size=16,
embed_dim=384,
depth=12,
num_heads=6,
mlp_ratio=4,
drop_path_rate=0.2,
conv_inplane=64,
n_points=4,
deform_num_heads=6,
cffn_ratio=0.25,
deform_ratio=1.0,
interaction_indexes=[[0, 2], [3, 5], [6, 8], [9, 11]],
window_attn=[True, True, False, True, True, False,
True, True, False, True, True, False],
window_size=[14, 14, None, 14, 14, None,
14, 14, None, 14, 14, None],
pretrained=pretrained),
neck=dict(
type='FPN',
in_channels=[384, 384, 384, 384],
out_channels=256,
num_outs=5),
roi_head=dict(
bbox_head=[
dict(
type='ConvFCBBoxHead',
num_shared_convs=4,
num_shared_fcs=1,
in_channels=256,
conv_out_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=False,
reg_decoded_bbox=True,
norm_cfg=dict(type='SyncBN', requires_grad=True),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
dict(
type='ConvFCBBoxHead',
num_shared_convs=4,
num_shared_fcs=1,
in_channels=256,
conv_out_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.05, 0.05, 0.1, 0.1]),
reg_class_agnostic=False,
reg_decoded_bbox=True,
norm_cfg=dict(type='SyncBN', requires_grad=True),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
dict(
type='ConvFCBBoxHead',
num_shared_convs=4,
num_shared_fcs=1,
in_channels=256,
conv_out_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.033, 0.033, 0.067, 0.067]),
reg_class_agnostic=False,
reg_decoded_bbox=True,
norm_cfg=dict(type='SyncBN', requires_grad=True),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=10.0))
]))

# optimizer
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# augmentation strategy originates from DETR / Sparse RCNN
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='AutoAugment',
policies=[
[
dict(type='Resize',
img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
multiscale_mode='value',
keep_ratio=True)
],
[
dict(type='Resize',
img_scale=[(400, 1333), (500, 1333), (600, 1333)],
multiscale_mode='value',
keep_ratio=True),
dict(type='RandomCrop',
crop_type='absolute_range',
crop_size=(384, 600),
allow_negative_crop=True),
dict(type='Resize',
img_scale=[(480, 1333), (512, 1333), (544, 1333),
(576, 1333), (608, 1333), (640, 1333),
(672, 1333), (704, 1333), (736, 1333),
(768, 1333), (800, 1333)],
multiscale_mode='value',
override=True,
keep_ratio=True)
]
]),
dict(type='RandomCrop',
crop_type='absolute_range',
crop_size=(1024, 1024),
allow_negative_crop=True),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
data = dict(train=dict(pipeline=train_pipeline))
optimizer = dict(
_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.05,
paramwise_cfg=dict(
custom_keys={
'level_embed': dict(decay_mult=0.),
'pos_embed': dict(decay_mult=0.),
'norm': dict(decay_mult=0.),
'bias': dict(decay_mult=0.)
}))
optimizer_config = dict(grad_clip=None)
fp16 = dict(loss_scale=dict(init_scale=512))
4 changes: 2 additions & 2 deletions detection/configs/mask_rcnn/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Mask R-CNN is a conceptually simple, flexible, and general framework for object
| Backbone | Pre-train | Lr schd | box AP | mask AP | #Param | Config | Download |
|:-------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------:|:-------:|:------:|:-------:|:------:|:--------------------------------------------------------------:|:-------------------------------------------------------------------------------------------------------------------------------:|
| ViT-Adapter-T | [DeiT-T](https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth) | 3x | 46.0 | 41.0 | 28M | [config](./mask_rcnn_deit_adapter_tiny_fpn_3x_coco.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.1.2/mask_rcnn_deit_adapter_tiny_fpn_3x_coco.pth.tar) |
| ViT-Adapter-S | [DeiT-S](https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth) | 3x | 48.2 | 42.8 | 48M | [config](./mask_rcnn_deit_adapter_small_fpn_3x_coco.py) | [model]() |
| ViT-Adapter-S | [DeiT-S](https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth) | 3x | 48.2 | 42.8 | 48M | [config](./mask_rcnn_deit_adapter_small_fpn_3x_coco.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.1.2/mask_rcnn_deit_adapter_small_fpn_3x_coco.pth.tar) |
| ViT-Adapter-B | [DeiT-B](https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth) | 3x | 49.6 | 43.6 | 120M | [config](./mask_rcnn_deit_adapter_base_fpn_3x_coco.py) | [model]() |
| ViT-Adapter-B | [Uni-Perceiver](https://github.com/czczup/ViT-Adapter/releases/download/v0.1.1/uniperceiver_pretrain.pth) | 3x | 50.7 | 44.9 | 120M | [config](./mask_rcnn_uniperceiver_adapter_base_fpn_3x_coco.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.1.1/mask_rcnn_uniperceiver_adapter_base_fpn_3x_coco.pth.tar) |
| ViT-Adapter-L | [AugReg](https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz) | 3x | 50.9 | 44.8 | 348M | [config](./mask_rcnn_augreg_adapter_large_fpn_3x_coco.py) | [model]() |
| ViT-Adapter-L | [AugReg](https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz) | 3x | 50.9 | 44.8 | 348M | [config](./mask_rcnn_augreg_adapter_large_fpn_3x_coco.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.1.2/mask_rcnn_augreg_adapter_large_fpn_3x_coco.pth.tar) |
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@
in_channels=[1024, 1024, 1024, 1024],
out_channels=256,
num_outs=5,
norm_cfg=dict(type='MMSyncBN', requires_grad=True)),
norm_cfg=dict(type='MMSyncBN', requires_grad=True)), # BN can be removed
roi_head=dict(
bbox_head=dict(norm_cfg=dict(type='MMSyncBN', requires_grad=True)),
mask_head=dict(norm_cfg=dict(type='MMSyncBN', requires_grad=True)))
bbox_head=dict(norm_cfg=dict(type='MMSyncBN', requires_grad=True)), # BN can be removed
mask_head=dict(norm_cfg=dict(type='MMSyncBN', requires_grad=True))) # BN can be removed
)
# optimizer
img_norm_cfg = dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
n_points=4,
deform_num_heads=6,
cffn_ratio=0.25,
deform_ratio=1.0,
interaction_indexes=[[0, 2], [3, 5], [6, 8], [9, 11]],
window_attn=[True, True, False, True, True, False,
True, True, False, True, True, False],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
# 'releases/download/v0.1.1/uniperceiver_pretrain.pth'
pretrained = 'pretrained/uniperceiver_pretrain.pth'
model = dict(
type='MaskRCNN',
backbone=dict(
_delete_=True,
type='UniPerceiverAdapter',
Expand Down

0 comments on commit b8857ea

Please sign in to comment.