training_code
jiemin.fang committed Jan 3, 2020
1 parent a33c143 commit 1c7da1c
Showing 24 changed files with 945 additions and 156 deletions.
28 changes: 23 additions & 5 deletions README.md
@@ -1,16 +1,20 @@
# DenseNAS
The evaluation code of the paper [Densely Connected Search Space for More Flexible Neural Architecture Search](https://arxiv.org/abs/1906.09607)

The code of the paper [Densely Connected Search Space for More Flexible Neural Architecture Search](https://arxiv.org/abs/1906.09607)

Neural architecture search (NAS) has dramatically advanced the development of neural network design. We revisit the search space design in most previous NAS methods and find that the number of blocks and the widths of blocks are set manually. However, block counts and block widths determine the network scale (depth and width) and greatly influence both the accuracy and the model cost (FLOPs/latency).

We propose to search block counts and block widths by designing a densely connected search space, i.e., DenseNAS. The new search space is represented as a dense super network, which is built upon our designed routing blocks. In the super network, routing blocks are densely connected and we search for the best path between them to derive the final architecture. We further propose a chained cost estimation algorithm to approximate the model cost during the search. Both the accuracy and model cost are optimized in DenseNAS.
![](./imgs/search_space.png)
![search_space](./imgs/search_space.png)


## Requirements

* pytorch 1.0.1
* python 3.6+

## Results

For experiments on the MobileNetV2-based search space, DenseNAS achieves 75.3\% top-1 accuracy on ImageNet with only 361M FLOPs and 17.9ms latency on a single TITAN-XP. The larger model searched by DenseNAS achieves 76.1\% accuracy with only 479M FLOPs. DenseNAS further promotes the ImageNet classification accuracies of ResNet-18, -34 and -50-B by 1.5\%, 0.5\% and 0.3\% with 200M, 600M and 680M FLOPs reduction respectively.

The comparison of model performance on ImageNet under the MobileNetV2-based search spaces.
@@ -27,7 +31,7 @@ The comparison of model performance on ImageNet under the ResNet-based search sp
<img src="imgs/res_comp.png" width="40%">
</p>

Our pre-trained models can be downloaded in the following links:
Our pre-trained models can be downloaded in the following links. The complete list of the models can be found in [DenseNAS_modelzoo](https://drive.google.com/open?id=183oIMF6IowZrj81kenVBkQoIMlip9kLo).

| Model | FLOPs | Latency | Top-1(%)|
|----------------------|-------|---------|---------|
@@ -39,8 +43,22 @@ Our pre-trained models can be downloaded in the following links:
| [DenseNAS-R2](https://drive.google.com/open?id=1Qawst3E2hqdam2TiTFo2BhBXS-M6AWdh) | 3.06B | 22.2ms | 75.8 |
| [DenseNAS-R3](https://drive.google.com/open?id=14RwIGWsurNvevhxL9AcnlngU0KR8WeX-) | 3.41B | 41.7ms | 78.0 |

![](imgs/archs.png)
![args](imgs/archs.png)

## Train

1. (Optional) We pack the ImageNet data as the lmdb file for faster IO. The lmdb files can be made as follows. If you don't want to use lmdb data, just set `__C.data.train_data_type='img'` in the training config file `imagenet_train_cfg.py`.

1). Generate the list of the image data.<br>
`python dataset/mk_img_list.py --image_path 'the path of your image data' --output_path 'the path to output the lmdb file'`

2). Use the image list obtained above to make the lmdb file.<br>
`python dataset/img2lmdb.py --image_path 'the path of your image data' --list_path 'the path of your image list' --output_path 'the path to output the lmdb file' --split 'split folder (train/val)'`

2. Train the model with the following script. You can also train your customized model by redefining `model` in `retrain.py`.<br>
`python -m run_apis.retrain --data_path 'The path of ImageNet data' --load_path 'The path you put the net_config of the model'`
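
The list-generation step in 1) can be pictured with a small sketch. It assumes the usual ImageFolder layout (one sub-directory per class) and writes one `relative/path label` line per image. The function name `make_image_list` and the line format are hypothetical, since the actual output format of `dataset/mk_img_list.py` is not shown in this commit.

```python
import os

# Assumption: only these extensions count as images.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")


def make_image_list(image_root):
    """Return 'relative/path label' lines for an ImageFolder-style directory.

    Hypothetical helper: the real dataset/mk_img_list.py may use another format.
    """
    # Sorted class directories receive consecutive integer labels.
    classes = sorted(
        d for d in os.listdir(image_root)
        if os.path.isdir(os.path.join(image_root, d))
    )
    lines = []
    for label, class_name in enumerate(classes):
        class_dir = os.path.join(image_root, class_name)
        for file_name in sorted(os.listdir(class_dir)):
            # Skip anything that does not look like an image file.
            if file_name.lower().endswith(IMAGE_EXTENSIONS):
                lines.append(f"{class_name}/{file_name} {label}")
    return lines
```

Writing the returned lines to a text file would give a list that a later packing step could read one entry at a time.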

## Evaluate

1. Download the related files of the pretrained model and put `net_config` and `weights.pt` into the `model_path`
2. `python validation.py --data_path 'The path of ImageNet data' --load_path 'The path you put the pre-trained model'`
2. `python -m run_apis.validation --data_path 'The path of ImageNet data' --load_path 'The path you put the pre-trained model'`
Empty file modified configs/__init__.py
100755 → 100644
Empty file.
62 changes: 62 additions & 0 deletions configs/imagenet_train_cfg.py
@@ -0,0 +1,62 @@
from tools.collections import AttrDict

__C = AttrDict()

cfg = __C

__C.net_type='mbv2' # mbv2 / res
__C.net_config=""

__C.train_params=AttrDict()
__C.train_params.epochs=240
__C.train_params.use_seed=False
__C.train_params.seed=0

__C.optim=AttrDict()
__C.optim.init_lr=0.5
__C.optim.min_lr=1e-5
__C.optim.lr_schedule='cosine' # cosine poly
__C.optim.momentum=0.9
__C.optim.weight_decay=4e-5
__C.optim.use_grad_clip=False
__C.optim.grad_clip=10
__C.optim.label_smooth=True
__C.optim.smooth_alpha=0.1

__C.optim.if_resume=False
__C.optim.resume=AttrDict()
__C.optim.resume.load_path=''
__C.optim.resume.load_epoch=0

__C.data=AttrDict()
__C.data.num_workers=16
__C.data.batch_size=256
__C.data.dataset='imagenet' #imagenet
__C.data.train_data_type='lmdb'
__C.data.val_data_type='img'
__C.data.patch_dataset=False
__C.data.num_examples=1281167
__C.data.input_size=(3,224,224)
__C.data.type_of_data_aug='random_sized' # random_sized / rand_scale
__C.data.random_sized=AttrDict()
__C.data.random_sized.min_scale=0.08
__C.data.mean=[0.485, 0.456, 0.406]
__C.data.std=[0.229, 0.224, 0.225]
__C.data.color=False
__C.data.lighting=False

__C.optim.use_warm_up=False
__C.optim.warm_up=AttrDict()
__C.optim.warm_up.epoch=5
__C.optim.warm_up.init_lr=0.0001
__C.optim.warm_up.target_lr=0.1

__C.optim.use_multi_stage=False
__C.optim.multi_stage=AttrDict()
__C.optim.multi_stage.stage_epochs=330

__C.optim.cosine=AttrDict()
__C.optim.cosine.use_restart=False
__C.optim.cosine.restart=AttrDict()
__C.optim.cosine.restart.lr_period=[10, 20, 40, 80, 160, 320]
__C.optim.cosine.restart.lr_step=[0, 10, 30, 70, 150, 310]
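
To make the optimizer settings above concrete, here is a minimal sketch of a per-epoch cosine schedule using `init_lr`, `min_lr` and `epochs` from the config. It assumes the standard cosine-annealing formula without restarts or warm-up (both are disabled by default above). The function `cosine_lr` is illustrative and not taken from the repository, whose scheduler may step per iteration or differ in other details.

```python
import math


def cosine_lr(epoch, epochs=240, init_lr=0.5, min_lr=1e-5):
    """Learning rate at `epoch` under plain cosine annealing (assumed formula)."""
    progress = epoch / epochs  # 0.0 at the start, 1.0 at the end of training
    return min_lr + 0.5 * (init_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Print the rate at a few points of the 240-epoch run.
for epoch in (0, 60, 120, 180, 240):
    print(f"epoch {epoch:3d}: lr = {cosine_lr(epoch):.5f}")
```

With these defaults the rate starts at `init_lr`, passes the midpoint of `init_lr` and `min_lr` halfway through training, and ends at `min_lr`.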
File renamed without changes.
Empty file added dataset/__init__.py
Empty file.
135 changes: 111 additions & 24 deletions dataset/imagenet_data.py
100755 → 100644
@@ -2,15 +2,11 @@

import torch
import torchvision
import torchvision.transforms as vision_transforms
import torchvision.transforms as transforms

from dataset import lmdb_dataset
from dataset import torchvision_extension as vision_transforms_extension

meanstd = {
'mean':[0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
}
from . import lmdb_dataset
from . import torchvision_extension as transforms_extension
from .prefetch_data import fast_collate

pca = {
'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
@@ -25,14 +21,15 @@
class ImageNet12(object):

def __init__(self, trainFolder, testFolder, num_workers=8, pin_memory=True,
size_images=224, scaled_size=256, data_config=None):
size_images=224, scaled_size=256, type_of_data_augmentation='rand_scale',
data_config=None):

self.data_config = data_config
self.trainFolder = trainFolder
self.testFolder = testFolder
self.num_workers = num_workers
self.pin_memory = pin_memory
self.patch_dataset = self.data_config.patch_dataset
self.meanstd = meanstd
self.pca = pca

#images will be rescaled to match this size
@@ -41,44 +38,134 @@ def __init__(self, trainFolder, testFolder, num_workers=8, pin_memory=True,
self.size_images = size_images
self.scaled_size = scaled_size

type_of_data_augmentation = type_of_data_augmentation.lower()
if type_of_data_augmentation not in ('rand_scale', 'random_sized'):
raise ValueError("type_of_data_augmentation must be either 'rand_scale' or 'random_sized'")
self.type_of_data_augmentation = type_of_data_augmentation


def _getTransformList(self, aug_type):

assert aug_type in ['rand_scale', 'random_sized', 'week_train', 'validation']
list_of_transforms = []

if aug_type == 'validation':
list_of_transforms.append(vision_transforms.Resize(self.scaled_size))
list_of_transforms.append(vision_transforms.CenterCrop(self.size_images))
list_of_transforms.append(vision_transforms.ToTensor())
list_of_transforms.append(vision_transforms.Normalize(mean=self.meanstd['mean'],
std=self.meanstd['std']))

return vision_transforms.Compose(list_of_transforms)
list_of_transforms.append(transforms.Resize(self.scaled_size))
list_of_transforms.append(transforms.CenterCrop(self.size_images))

elif aug_type == 'week_train':
list_of_transforms.append(transforms.Resize(256))
list_of_transforms.append(transforms.RandomCrop(self.size_images))
list_of_transforms.append(transforms.RandomHorizontalFlip())

else:
if aug_type == 'rand_scale':
list_of_transforms.append(transforms_extension.RandomScale(256, 480))
list_of_transforms.append(transforms.RandomCrop(self.size_images))
list_of_transforms.append(transforms.RandomHorizontalFlip())

elif aug_type == 'random_sized':
list_of_transforms.append(transforms.RandomResizedCrop(self.size_images,
scale=(self.data_config.random_sized.min_scale, 1.0)))
list_of_transforms.append(transforms.RandomHorizontalFlip())

if self.data_config.color:
list_of_transforms.append(transforms.ColorJitter(brightness=0.4,
contrast=0.4,
saturation=0.4))
if self.data_config.lighting:
list_of_transforms.append(transforms_extension.Lighting(alphastd=0.1,
eigval=self.pca['eigval'],
eigvec=self.pca['eigvec']))
return transforms.Compose(list_of_transforms)


def _getTrainSet(self):

train_transform = self._getTransformList(self.type_of_data_augmentation)

if self.data_config.train_data_type == 'img':
train_set = torchvision.datasets.ImageFolder(self.trainFolder, train_transform)
elif self.data_config.train_data_type == 'lmdb':
train_set = lmdb_dataset.ImageFolder(self.trainFolder,
os.path.join(self.trainFolder, '..', 'train_datalist'),
train_transform,
patch_dataset=self.patch_dataset)
self.train_num_examples = train_set.__len__()

return train_set


def _getWeekTrainSet(self):

train_transform = self._getTransformList('week_train')
if self.data_config.train_data_type == 'img':
train_set = torchvision.datasets.ImageFolder(self.trainFolder, train_transform)
elif self.data_config.train_data_type == 'lmdb':
train_set = lmdb_dataset.ImageFolder(self.trainFolder,
os.path.join(self.trainFolder, '..', 'train_datalist'),
train_transform,
patch_dataset=self.patch_dataset)
self.train_num_examples = train_set.__len__()
return train_set


def _getTestSet(self):
# first we define the test transform we will apply to the dataset

test_transform = self._getTransformList('validation')

if self.data_config.val_data_type == 'img':
test_set = torchvision.datasets.ImageFolder(self.testFolder, test_transform)
elif self.data_config.val_data_type == 'lmdb':
test_set = lmdb_dataset.ImageFolder(self.testFolder,
os.path.join(self.testFolder, '..', 'val_datalist'),
test_transform)
self.test_num_examples = test_set.__len__()

return test_set


def getTestLoader(self, batch_size, shuffle=False):
def getTrainLoader(self, batch_size, shuffle=True):

test_set = self._getTestSet()
test_loader = torch.utils.data.DataLoader(test_set,
train_set = self._getTrainSet()
train_loader = torch.utils.data.DataLoader(
train_set, batch_size=batch_size, shuffle=shuffle,
num_workers=self.num_workers, pin_memory=self.pin_memory,
sampler=None, collate_fn=fast_collate)
return train_loader


def getWeekTrainLoader(self, batch_size, shuffle=True):

train_set = self._getWeekTrainSet()
train_loader = torch.utils.data.DataLoader(train_set,
batch_size=batch_size,
shuffle=shuffle,
num_workers=self.num_workers,
pin_memory=self.pin_memory)
pin_memory=self.pin_memory,
collate_fn=fast_collate)
return train_loader


def getTestLoader(self, batch_size, shuffle=False):

test_set = self._getTestSet()

test_loader = torch.utils.data.DataLoader(
test_set, batch_size=batch_size, shuffle=shuffle,
num_workers=self.num_workers, pin_memory=self.pin_memory, sampler=None,
collate_fn=fast_collate)
return test_loader


def getTrainTestLoader(self, batch_size, train_shuffle=True, val_shuffle=False):

train_loader = self.getTrainLoader(batch_size, train_shuffle)
test_loader = self.getTestLoader(batch_size, val_shuffle)
return train_loader, test_loader


def getSetTrainTestLoader(self, batch_size, train_shuffle=True, val_shuffle=False):

train_loader = self.getTrainLoader(batch_size, train_shuffle)
week_train_loader = self.getWeekTrainLoader(batch_size, train_shuffle)
test_loader = self.getTestLoader(batch_size, val_shuffle)
return (train_loader, week_train_loader), test_loader
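
The `'random_sized'` branch above relies on `transforms.RandomResizedCrop` with `scale=(min_scale, 1.0)`. As a rough illustration of what that scale range means, the sketch below samples a crop box in plain Python: it draws an area fraction uniformly from the scale range and an aspect ratio log-uniformly from (3/4, 4/3), retries a few times, and falls back to a centred square crop. This is a simplified re-implementation written for illustration, not torchvision's code; `sample_crop_box` and its defaults are assumptions.

```python
import math
import random


def sample_crop_box(width, height, min_scale=0.08, ratio=(3 / 4, 4 / 3),
                    attempts=10, rng=random):
    """Return (left, top, crop_w, crop_h) for a random-sized crop (sketch)."""
    area = width * height
    log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
    for _ in range(attempts):
        # Area fraction is uniform in [min_scale, 1.0].
        target_area = area * rng.uniform(min_scale, 1.0)
        # Aspect ratio is log-uniform, so 3/4 and 4/3 are equally likely.
        aspect = math.exp(rng.uniform(*log_ratio))
        crop_w = int(round(math.sqrt(target_area * aspect)))
        crop_h = int(round(math.sqrt(target_area / aspect)))
        if 0 < crop_w <= width and 0 < crop_h <= height:
            left = rng.randint(0, width - crop_w)
            top = rng.randint(0, height - crop_h)
            return left, top, crop_w, crop_h
    # Simplified fallback: the largest centred square crop.
    side = min(width, height)
    return (width - side) // 2, (height - side) // 2, side, side
```

In the data pipeline the sampled box would then be resized to `size_images` (224 by default). With `min_scale=0.08`, as in the training config, a crop may cover as little as 8% of the image area.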