PyTorch implementation of classification, semantic segmentation, pose estimation and object detection
- Image Classification
- Semantic Segmentation (in progress)
- Object Detection (in progress)
- Pose Estimation (in progress)
- Install PyTorch>=0.3.0
- Install torchvision>=0.2.0
- Clone the repository recursively:
git clone --recursive https://github.com/soeaver/pytorch-priv
- pip install easydict
For training:
1. Modify the .yml file in ./cfg/imagenet/air50-1x64d:
   - ckpt sets where the checkpoints are saved.
   - To use a cosine learning rate, set cosine_lr: True; lr_schedule and gamma are then ignored.
   - To resume training, set resume: to the path of model.pth.tar and modify start_epoch.
   - rotation, pixel_jitter and grayscale are extra data augmentations, recommended only when training complex networks.
2. Train a network:
python cls_train.py --cfg ./cfg/imagenet/air50_1x64d.yml
2.1 Training with mixup (optional):
python tools/cls_mixup_train.py --cfg ./cfg/imagenet/air50_1x64d_mixup.yml
For better performance:
- double the number of training epochs when training with mixup;
- then train for a few extra epochs without mixup.
2.2 Training on CIFAR (optional):
python tools/cls_cifar.py --cfg ./cfg/cifar10/resnext29_8x64d.yml
Or train with mixup (usually with weight_decay: 0.0001):
python tools/cls_mixup_cifar.py --cfg ./cfg/cifar10/resnext29_8x64d_mixup.yml
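The cosine_lr bullet in step 1 switches between two learning-rate schedules. The sketch below contrasts them, assuming the common cosine-annealing form lr = base_lr * 0.5 * (1 + cos(pi * epoch / total_epochs)) and a step schedule in which lr_schedule lists the epochs at which the rate is multiplied by gamma. Both are assumptions: the repository's actual update (per-iteration vs per-epoch, warm-up, minimum rate) may differ, and base_lr, total_epochs and the function names are hypothetical.

```python
import math
from typing import Sequence


def step_lr(base_lr: float, epoch: int, lr_schedule: Sequence[int], gamma: float) -> float:
    """Step decay (assumed semantics): multiply base_lr by gamma once for
    every milestone epoch in lr_schedule that has already been reached."""
    num_decays = sum(1 for milestone in lr_schedule if epoch >= milestone)
    return base_lr * gamma ** num_decays


def cosine_lr(base_lr: float, epoch: int, total_epochs: int) -> float:
    """Cosine annealing from base_lr at epoch 0 down to 0 at total_epochs.
    lr_schedule and gamma play no role here, matching the README note."""
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * epoch / total_epochs))


if __name__ == "__main__":
    # Hypothetical settings, for illustration only.
    base_lr, total_epochs, milestones, gamma = 0.1, 100, [30, 60, 90], 0.1
    for epoch in (0, 30, 50, 60, 90, 99):
        print(f"epoch {epoch:3d}: step {step_lr(base_lr, epoch, milestones, gamma):.6f}"
              f"  cosine {cosine_lr(base_lr, epoch, total_epochs):.6f}")
```

With these settings the step schedule drops by 10x at epochs 30, 60 and 90, while the cosine schedule decays smoothly to nearly zero by the last epoch.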
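For readers new to mixup (sections 2.1 and 2.2), here is a minimal pure-Python sketch of the standard recipe: draw lam from Beta(alpha, alpha), blend each sample with a randomly paired sample from the same batch, and weight the cross-entropy against both labels by lam and 1 - lam. It is not taken from tools/cls_mixup_train.py; alpha and all function names are hypothetical, and plain lists stand in for image tensors.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple


def mixup_batch(inputs: Sequence[Sequence[float]], labels: Sequence[int],
                alpha: float = 1.0, rng: Optional[random.Random] = None
                ) -> Tuple[List[List[float]], List[int], List[int], float]:
    """Blend every sample with a randomly paired sample from the same batch.

    Returns (mixed_inputs, labels_a, labels_b, lam), where lam ~ Beta(alpha, alpha)
    weights the original sample and 1 - lam weights its partner.
    """
    if rng is None:
        rng = random.Random()
    lam = rng.betavariate(alpha, alpha) if alpha > 0 else 1.0
    perm = list(range(len(inputs)))
    rng.shuffle(perm)  # perm[i] is the partner of sample i
    mixed = [[lam * a + (1.0 - lam) * b for a, b in zip(inputs[i], inputs[j])]
             for i, j in enumerate(perm)]
    return mixed, list(labels), [labels[j] for j in perm], lam


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Softmax cross-entropy for one sample, using the log-sum-exp trick."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def mixup_loss(logits: Sequence[float], label_a: int, label_b: int, lam: float) -> float:
    """Mixup objective for one sample: lam-weighted loss against both labels."""
    return lam * cross_entropy(logits, label_a) + (1.0 - lam) * cross_entropy(logits, label_b)
```

Because each mixed image carries a blend of two labels, the model sees softer targets, which is one reason mixup runs typically need more epochs, as the tips under 2.1 suggest.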
For evaluating:
1. Modify the .yml file in ./cfg/imagenet/air50-1x64d:
   - set pretrained: to the path of model.pth.tar
   - set evaluate: True
2. Evaluate a network:
python cls_train.py --cfg ./cfg/imagenet/air50_1x64d.yml
For evaluating image by image:
1. Modify the tools/cls_eval.py file.
2. Evaluate a network:
python tools/cls_eval.py
Single-crop (224x224) validation error rates on ImageNet are reported.
Network | FLOPs (M) | Params (M) | Top-1 Error (%) | Top-5 Error (%) | Speed (im/sec) |
---|---|---|---|---|---|
resnet50-1x64d | 4342.1 | 25.5 | 23.52 | 7.01 | 157.1 |
resnet101-1x64d | 8039.0 | 44.5 | 22.18 | 6.23 | 91.7 |
- Speed measured on a single Titan Xp GPU with batch_size: 1.
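The README does not describe how the im/sec figures were timed. A minimal sketch of one common protocol, assuming a callable forward that processes one image at a time (batch_size: 1) and a few untimed warm-up iterations; the function name and warm-up count are hypothetical. On a GPU, PyTorch launches kernels asynchronously, so a real benchmark would also call torch.cuda.synchronize() before reading the clock.

```python
import time
from typing import Callable, Sequence


def images_per_second(forward: Callable[[object], object], images: Sequence[object],
                      warmup: int = 10) -> float:
    """Run `forward` on one image at a time and return throughput in images/sec.

    The first `warmup` images are processed but not timed, so one-off costs
    such as memory allocation do not skew the measurement.
    """
    if len(images) <= warmup:
        raise ValueError("need more images than warm-up iterations")
    for img in images[:warmup]:
        forward(img)
    timed = images[warmup:]
    start = time.perf_counter()
    for img in timed:
        forward(img)
    elapsed = time.perf_counter() - start
    return len(timed) / elapsed
```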
Validation error rates on CIFAR-10 and CIFAR-100 are reported.
Network | FLOPs (M) | Params (M) | CIFAR-10 Top-1 Error (%) | CIFAR-100 Top-1 Error (%) |
---|---|---|---|---|
resnext29-8x64d | 5387.2 | 34.4 | 3.73 | 18.55 |
resnext29-8x64d-mixup | 5387.2 | 34.4 | 2.90 | -- |
resnext29-8x64d-re | 5387.2 | 34.4 | 3.55 | -- |
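Both tables report error rates rather than accuracy: top-1 error counts an image as wrong when its highest-scoring class is not the label, and top-5 error only when the label is missing from the five highest scores. A minimal stdlib sketch of that definition, with a hypothetical function name; the repository's own evaluation code is not shown here.

```python
from typing import Sequence


def topk_error(scores: Sequence[Sequence[float]], targets: Sequence[int], k: int) -> float:
    """Percentage of samples whose true class is not among the k highest scores."""
    if not targets:
        raise ValueError("need at least one sample")
    misses = 0
    for row, target in zip(scores, targets):
        # Class indices sorted by descending score; keep the best k.
        top_k = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        if target not in top_k:
            misses += 1
    return 100.0 * misses / len(targets)
```

Calling it with k=1 and k=5 on the same score matrix yields the Top-1 and Top-5 columns; top-k error can only shrink as k grows.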
pytorch-priv is released under the MIT License (refer to the LICENSE file for details).
Feel free to create a pull request if you find any bugs or want to contribute (e.g., more datasets and more network structures).