From 5a25c8b3b70956174b3c1876321967e265b2da42 Mon Sep 17 00:00:00 2001
From: lizhiqi
Date: Mon, 27 Jun 2022 10:43:49 +0800
Subject: [PATCH] support fp16 and batchsize>1

---
 docs/getting_started.md                       |   8 +
 .../bevformer_fp16/bevformer_tiny_fp16.py     | 272 ++++++++++++++++++
 projects/mmdet3d_plugin/bevformer/__init__.py |   2 +
 .../bevformer/apis/mmdet_train.py             |  40 ++-
 .../mmdet3d_plugin/bevformer/apis/train.py    |   2 +
 .../bevformer/detectors/__init__.py           |   3 +-
 .../bevformer/detectors/bevformer.py          |   2 +-
 .../bevformer/detectors/bevformer_fp16.py     |  89 ++++++
 .../bevformer/hooks/__init__.py               |   1 +
 .../bevformer/hooks/custom_hooks.py           |  14 +
 .../bevformer/modules/transformer.py          |   3 +
 .../bevformer/runner/__init__.py              |   1 +
 .../bevformer/runner/epoch_based_runner.py    |  96 +++++++
 tools/fp16/dist_train.sh                      |   9 +
 tools/fp16/train.py                           | 271 +++++++++++++++++
 15 files changed, 801 insertions(+), 12 deletions(-)
 create mode 100644 projects/configs/bevformer_fp16/bevformer_tiny_fp16.py
 create mode 100644 projects/mmdet3d_plugin/bevformer/detectors/bevformer_fp16.py
 create mode 100644 projects/mmdet3d_plugin/bevformer/hooks/__init__.py
 create mode 100644 projects/mmdet3d_plugin/bevformer/hooks/custom_hooks.py
 create mode 100644 projects/mmdet3d_plugin/bevformer/runner/__init__.py
 create mode 100644 projects/mmdet3d_plugin/bevformer/runner/epoch_based_runner.py
 create mode 100755 tools/fp16/dist_train.sh
 create mode 100644 tools/fp16/train.py

diff --git a/docs/getting_started.md b/docs/getting_started.md
index 8a3aa30..bdcba68 100644
--- a/docs/getting_started.md
+++ b/docs/getting_started.md
@@ -16,3 +16,11 @@ Eval BEVFormer with 8 GPUs
 Note: using 1 GPU to eval can obtain slightly higher performance because continuous video may be truncated with multiple GPUs. By default we report the score evaled with 8 GPUs.
+
+# Using FP16 to train the model.
+
+We provide another script to train BEVFormer with FP16.
+
+```
+./tools/fp16/dist_train.sh ./projects/configs/bevformer_fp16/bevformer_tiny_fp16.py 8
+```
\ No newline at end of file
diff --git a/projects/configs/bevformer_fp16/bevformer_tiny_fp16.py b/projects/configs/bevformer_fp16/bevformer_tiny_fp16.py
new file mode 100644
index 0000000..aa1e043
--- /dev/null
+++ b/projects/configs/bevformer_fp16/bevformer_tiny_fp16.py
@@ -0,0 +1,272 @@
+# BEVFormer-tiny consumes at least 6700M GPU memory
+# compared to bevformer_base, bevformer_tiny has
+# smaller backbone: R101-DCN -> R50
+# smaller BEV: 200*200 -> 50*50
+# fewer encoder layers: 6 -> 3
+# smaller input size: 1600*900 -> 800*450
+# multi-scale features -> single-scale features (C5)
+
+
+_base_ = [
+    '../datasets/custom_nus-3d.py',
+    '../_base_/default_runtime.py'
+]
+#
+plugin = True
+plugin_dir = 'projects/mmdet3d_plugin/'
+
+# If point cloud range is changed, the models should also change their point
+# cloud range accordingly
+point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
+voxel_size = [0.2, 0.2, 8]
+
+
+
+
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+# For nuScenes we usually do 10-class detection
+class_names = [
+    'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
+    'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
+]
+
+input_modality = dict(
+    use_lidar=False,
+    use_camera=True,
+    use_radar=False,
+    use_map=False,
+    use_external=True)
+
+_dim_ = 256
+_pos_dim_ = _dim_//2
+_ffn_dim_ = _dim_*2
+_num_levels_ = 1
+bev_h_ = 50
+bev_w_ = 50
+queue_length = 3 # each sequence contains `queue_length` frames.
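+# How `queue_length` interacts with the FP16 runner added by this patch: the
+# first `queue_length - 1` frames of each queue are passed through a
+# gradient-free copy of the model (`eval_model`) to produce `prev_bev`, and
+# only the last frame of the queue back-propagates (see EpochBasedRunner_video
+# below).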
+ +model = dict( + type='BEVFormer_fp16', + use_grid_mask=True, + video_test_mode=True, + pretrained=dict(img='torchvision://resnet50'), + img_backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(3,), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=False), + norm_eval=True, + style='pytorch'), + img_neck=dict( + type='FPN', + in_channels=[2048], + out_channels=_dim_, + start_level=0, + add_extra_convs='on_output', + num_outs=_num_levels_, + relu_before_extra_convs=True), + pts_bbox_head=dict( + type='BEVFormerHead', + bev_h=bev_h_, + bev_w=bev_w_, + num_query=900, + num_classes=10, + in_channels=_dim_, + sync_cls_avg_factor=True, + with_box_refine=True, + as_two_stage=False, + transformer=dict( + type='PerceptionTransformer', + rotate_prev_bev=True, + use_shift=True, + use_can_bus=True, + embed_dims=_dim_, + encoder=dict( + type='BEVFormerEncoder', + num_layers=3, + pc_range=point_cloud_range, + num_points_in_pillar=4, + return_intermediate=False, + transformerlayers=dict( + type='BEVFormerLayer', + attn_cfgs=[ + dict( + type='TemporalSelfAttention', + embed_dims=_dim_, + num_levels=1), + dict( + type='SpatialCrossAttention', + pc_range=point_cloud_range, + deformable_attention=dict( + type='MSDeformableAttention3D', + embed_dims=_dim_, + num_points=8, + num_levels=_num_levels_), + embed_dims=_dim_, + ) + ], + feedforward_channels=_ffn_dim_, + ffn_dropout=0.1, + operation_order=('self_attn', 'norm', 'cross_attn', 'norm', + 'ffn', 'norm'))), + decoder=dict( + type='DetectionTransformerDecoder', + num_layers=6, + return_intermediate=True, + transformerlayers=dict( + type='DetrTransformerDecoderLayer', + attn_cfgs=[ + dict( + type='MultiheadAttention', + embed_dims=_dim_, + num_heads=8, + dropout=0.1), + dict( + type='CustomMSDeformableAttention', + embed_dims=_dim_, + num_levels=1), + ], + + feedforward_channels=_ffn_dim_, + ffn_dropout=0.1, + operation_order=('self_attn', 'norm', 'cross_attn', 'norm', + 'ffn', 'norm')))), + bbox_coder=dict( + type='NMSFreeCoder', + post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0], + pc_range=point_cloud_range, + max_num=300, + voxel_size=voxel_size, + num_classes=10), + positional_encoding=dict( + type='LearnedPositionalEncoding', + num_feats=_pos_dim_, + row_num_embed=bev_h_, + col_num_embed=bev_w_, + ), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0), + loss_bbox=dict(type='L1Loss', loss_weight=0.25), + loss_iou=dict(type='GIoULoss', loss_weight=0.0)), + # model training and testing settings + train_cfg=dict(pts=dict( + grid_size=[512, 512, 1], + voxel_size=voxel_size, + point_cloud_range=point_cloud_range, + out_size_factor=4, + assigner=dict( + type='HungarianAssigner3D', + cls_cost=dict(type='FocalLossCost', weight=2.0), + reg_cost=dict(type='BBox3DL1Cost', weight=0.25), + iou_cost=dict(type='IoUCost', weight=0.0), # Fake cost. This is just to make it compatible with DETR head. 
+ pc_range=point_cloud_range)))) + +dataset_type = 'CustomNuScenesDataset' +data_root = 'data/nuscenes/' +file_client_args = dict(backend='disk') + + +train_pipeline = [ + dict(type='LoadMultiViewImageFromFiles', to_float32=True), + dict(type='PhotoMetricDistortionMultiViewImage'), + dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True, with_attr_label=False), + dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), + dict(type='ObjectNameFilter', classes=class_names), + dict(type='NormalizeMultiviewImage', **img_norm_cfg), + dict(type='RandomScaleImageMultiViewImage', scales=[0.5]), + dict(type='PadMultiViewImage', size_divisor=32), + dict(type='DefaultFormatBundle3D', class_names=class_names), + dict(type='CustomCollect3D', keys=['gt_bboxes_3d', 'gt_labels_3d', 'img']) +] + +test_pipeline = [ + dict(type='LoadMultiViewImageFromFiles', to_float32=True), + dict(type='NormalizeMultiviewImage', **img_norm_cfg), + + dict( + type='MultiScaleFlipAug3D', + img_scale=(1600, 900), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict(type='RandomScaleImageMultiViewImage', scales=[0.5]), + dict(type='PadMultiViewImage', size_divisor=32), + dict( + type='DefaultFormatBundle3D', + class_names=class_names, + with_label=False), + dict(type='CustomCollect3D', keys=['img']) + ]) +] + +data = dict( + samples_per_gpu=2, + workers_per_gpu=8, + train=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'nuscenes_infos_temporal_train.pkl', + pipeline=train_pipeline, + classes=class_names, + modality=input_modality, + test_mode=False, + use_valid_flag=True, + bev_size=(bev_h_, bev_w_), + queue_length=queue_length, + # we use box_type_3d='LiDAR' in kitti and nuscenes dataset + # and box_type_3d='Depth' in sunrgbd and scannet dataset. + box_type_3d='LiDAR'), + val=dict(type=dataset_type, + data_root=data_root, + ann_file=data_root + 'nuscenes_infos_temporal_val.pkl', + pipeline=test_pipeline, bev_size=(bev_h_, bev_w_), + classes=class_names, modality=input_modality, samples_per_gpu=1), + test=dict(type=dataset_type, + data_root=data_root, + ann_file=data_root + 'nuscenes_infos_temporal_val.pkl', + pipeline=test_pipeline, bev_size=(bev_h_, bev_w_), + classes=class_names, modality=input_modality), + shuffler_sampler=dict(type='DistributedGroupSampler'), + nonshuffler_sampler=dict(type='DistributedSampler') +) + +optimizer = dict( + type='AdamW', + lr=2.8e-4, + paramwise_cfg=dict( + custom_keys={ + 'img_backbone': dict(lr_mult=0.1), + }), + weight_decay=0.01) + +optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) +# learning policy +lr_config = dict( + policy='CosineAnnealing', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3) +total_epochs = 24 +evaluation = dict(interval=1, pipeline=test_pipeline) + +runner = dict(type='EpochBasedRunner_video', max_epochs=total_epochs) + +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + dict(type='TensorboardLoggerHook') + ]) + +fp16 = dict(loss_scale=512.) 
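+# A fixed loss scale of 512 is used above; if FP16 training overflows (NaN
+# losses), mmcv's Fp16OptimizerHook also accepts loss_scale='dynamic' as an
+# alternative (not used by this patch).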
+checkpoint_config = dict(interval=1) +custom_hooks = [dict(type='TransferWeight',priority='LOWEST')] \ No newline at end of file diff --git a/projects/mmdet3d_plugin/bevformer/__init__.py b/projects/mmdet3d_plugin/bevformer/__init__.py index 745867b..98d6e7e 100644 --- a/projects/mmdet3d_plugin/bevformer/__init__.py +++ b/projects/mmdet3d_plugin/bevformer/__init__.py @@ -2,3 +2,5 @@ from .dense_heads import * from .detectors import * from .modules import * +from .runner import * +from .hooks import * diff --git a/projects/mmdet3d_plugin/bevformer/apis/mmdet_train.py b/projects/mmdet3d_plugin/bevformer/apis/mmdet_train.py index 813643a..e57bd22 100644 --- a/projects/mmdet3d_plugin/bevformer/apis/mmdet_train.py +++ b/projects/mmdet3d_plugin/bevformer/apis/mmdet_train.py @@ -31,6 +31,7 @@ def custom_train_detector(model, distributed=False, validate=False, timestamp=None, + eval_model=None, meta=None): logger = get_root_logger(cfg.log_level) @@ -76,10 +77,19 @@ def custom_train_detector(model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False, find_unused_parameters=find_unused_parameters) - + if eval_model is not None: + eval_model = MMDistributedDataParallel( + eval_model.cuda(), + device_ids=[torch.cuda.current_device()], + broadcast_buffers=False, + find_unused_parameters=find_unused_parameters) else: model = MMDataParallel( model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) + if eval_model is not None: + eval_model = MMDataParallel( + eval_model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) + # build runner optimizer = build_optimizer(model, cfg.optimizer) @@ -95,15 +105,25 @@ def custom_train_detector(model, else: if 'total_epochs' in cfg: assert cfg.total_epochs == cfg.runner.max_epochs - - runner = build_runner( - cfg.runner, - default_args=dict( - model=model, - optimizer=optimizer, - work_dir=cfg.work_dir, - logger=logger, - meta=meta)) + if eval_model is not None: + runner = build_runner( + cfg.runner, + default_args=dict( + model=model, + eval_model=eval_model, + optimizer=optimizer, + work_dir=cfg.work_dir, + logger=logger, + meta=meta)) + else: + runner = build_runner( + cfg.runner, + default_args=dict( + model=model, + optimizer=optimizer, + work_dir=cfg.work_dir, + logger=logger, + meta=meta)) # an ugly workaround to make .log and .log.json filenames the same runner.timestamp = timestamp diff --git a/projects/mmdet3d_plugin/bevformer/apis/train.py b/projects/mmdet3d_plugin/bevformer/apis/train.py index 730d41a..f9391e6 100644 --- a/projects/mmdet3d_plugin/bevformer/apis/train.py +++ b/projects/mmdet3d_plugin/bevformer/apis/train.py @@ -14,6 +14,7 @@ def custom_train_model(model, distributed=False, validate=False, timestamp=None, + eval_model=None, meta=None): """A function wrapper for launching model training according to cfg. 
@@ -30,6 +31,7 @@ def custom_train_model(model,
             distributed=distributed,
             validate=validate,
             timestamp=timestamp,
+            eval_model=eval_model,
             meta=meta)
 
 
diff --git a/projects/mmdet3d_plugin/bevformer/detectors/__init__.py b/projects/mmdet3d_plugin/bevformer/detectors/__init__.py
index dda36d9..4c39fd3 100644
--- a/projects/mmdet3d_plugin/bevformer/detectors/__init__.py
+++ b/projects/mmdet3d_plugin/bevformer/detectors/__init__.py
@@ -1 +1,2 @@
-from .bevformer import BEVFormer
\ No newline at end of file
+from .bevformer import BEVFormer
+from .bevformer_fp16 import BEVFormer_fp16
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/bevformer/detectors/bevformer.py b/projects/mmdet3d_plugin/bevformer/detectors/bevformer.py
index acc9af6..2a058ef 100644
--- a/projects/mmdet3d_plugin/bevformer/detectors/bevformer.py
+++ b/projects/mmdet3d_plugin/bevformer/detectors/bevformer.py
@@ -175,7 +175,7 @@ def obtain_history_bev(self, imgs_queue, img_metas_list):
         self.train()
         return prev_bev
 
-    @auto_fp16(apply_to=('img', 'prev_bev', 'points'))
+    @auto_fp16(apply_to=('img', 'points'))
     def forward_train(self,
                       points=None,
                       img_metas=None,
diff --git a/projects/mmdet3d_plugin/bevformer/detectors/bevformer_fp16.py b/projects/mmdet3d_plugin/bevformer/detectors/bevformer_fp16.py
new file mode 100644
index 0000000..5325e3c
--- /dev/null
+++ b/projects/mmdet3d_plugin/bevformer/detectors/bevformer_fp16.py
@@ -0,0 +1,89 @@
+# ---------------------------------------------
+# Copyright (c) OpenMMLab. All rights reserved.
+# ---------------------------------------------
+# Modified by Zhiqi Li
+# ---------------------------------------------
+
+import torch
+from mmcv.runner import force_fp32, auto_fp16
+from mmdet.models import DETECTORS
+from mmdet3d.core import bbox3d2result
+from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
+from projects.mmdet3d_plugin.models.utils.grid_mask import GridMask
+from projects.mmdet3d_plugin.bevformer.detectors.bevformer import BEVFormer
+import time
+import copy
+import numpy as np
+import mmdet3d
+from projects.mmdet3d_plugin.models.utils.bricks import run_time
+
+
+@DETECTORS.register_module()
+class BEVFormer_fp16(BEVFormer):
+    """
+    The default version of BEVFormer currently cannot support FP16.
+    We provide this version to resolve this issue.
+    """
+
+    @auto_fp16(apply_to=('img', 'prev_bev', 'points'))
+    def forward_train(self,
+                      points=None,
+                      img_metas=None,
+                      gt_bboxes_3d=None,
+                      gt_labels_3d=None,
+                      gt_labels=None,
+                      gt_bboxes=None,
+                      img=None,
+                      proposals=None,
+                      gt_bboxes_ignore=None,
+                      img_depth=None,
+                      img_mask=None,
+                      prev_bev=None,
+                      ):
+        """Forward training function.
+        Args:
+            points (list[torch.Tensor], optional): Points of each sample.
+                Defaults to None.
+            img_metas (list[dict], optional): Meta information of each sample.
+                Defaults to None.
+            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
+                Ground truth 3D boxes. Defaults to None.
+            gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
+                of 3D boxes. Defaults to None.
+            gt_labels (list[torch.Tensor], optional): Ground truth labels
+                of 2D boxes in images. Defaults to None.
+            gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
+                images. Defaults to None.
+            img (torch.Tensor, optional): Images of each sample with shape
+                (N, C, H, W). Defaults to None.
+            proposals (list[torch.Tensor], optional): Predicted proposals
+                used for training Fast RCNN. Defaults to None.
+            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
+                2D boxes in images to be ignored. Defaults to None.
+        Returns:
+            dict: Losses of different branches.
+        """
+
+        img_feats = self.extract_feat(img=img, img_metas=img_metas)
+
+        losses = dict()
+        losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d,
+                                            gt_labels_3d, img_metas,
+                                            gt_bboxes_ignore, prev_bev=prev_bev)
+        losses.update(losses_pts)
+        return losses
+
+
+    def val_step(self, data, optimizer):
+        """
+        In BEVFormer_fp16, we use this `val_step` function to infer `prev_bev`.
+        This is not the standard behavior of `val_step`.
+        """
+
+        img = data['img']
+        img_metas = data['img_metas']
+        img_feats = self.extract_feat(img=img, img_metas=img_metas)
+        prev_bev = data.get('prev_bev', None)
+        prev_bev = self.pts_bbox_head(img_feats, img_metas, prev_bev=prev_bev, only_bev=True)
+        return prev_bev
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/bevformer/hooks/__init__.py b/projects/mmdet3d_plugin/bevformer/hooks/__init__.py
new file mode 100644
index 0000000..aa04ec1
--- /dev/null
+++ b/projects/mmdet3d_plugin/bevformer/hooks/__init__.py
@@ -0,0 +1 @@
+from .custom_hooks import TransferWeight
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/bevformer/hooks/custom_hooks.py b/projects/mmdet3d_plugin/bevformer/hooks/custom_hooks.py
new file mode 100644
index 0000000..091738a
--- /dev/null
+++ b/projects/mmdet3d_plugin/bevformer/hooks/custom_hooks.py
@@ -0,0 +1,14 @@
+from mmcv.runner.hooks.hook import HOOKS, Hook
+from projects.mmdet3d_plugin.models.utils import run_time
+
+
+@HOOKS.register_module()
+class TransferWeight(Hook):
+
+    def __init__(self, every_n_inters=1):
+        self.every_n_inters = every_n_inters
+
+    def after_train_iter(self, runner):
+        if self.every_n_inner_iters(runner, self.every_n_inters):
+            runner.eval_model.load_state_dict(runner.model.state_dict())
+
diff --git a/projects/mmdet3d_plugin/bevformer/modules/transformer.py b/projects/mmdet3d_plugin/bevformer/modules/transformer.py
index a27c5cc..b740fcc 100644
--- a/projects/mmdet3d_plugin/bevformer/modules/transformer.py
+++ b/projects/mmdet3d_plugin/bevformer/modules/transformer.py
@@ -20,6 +20,7 @@ from .spatial_cross_attention import MSDeformableAttention3D
 from .decoder import CustomMSDeformableAttention
 from projects.mmdet3d_plugin.models.utils.bricks import run_time
+from mmcv.runner import force_fp32, auto_fp16
 
 
 @TRANSFORMER.register_module()
@@ -99,6 +100,7 @@ def init_weights(self):
         xavier_init(self.reference_points, distribution='uniform', bias=0.)
         xavier_init(self.can_bus_mlp, distribution='uniform', bias=0.)
 
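+    # mmcv's auto_fp16 decorator casts the tensor arguments named in
+    # `apply_to` to half precision when fp16 is enabled, so the decorated
+    # entry points accept either fp32 or fp16 inputs.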
+    @auto_fp16(apply_to=('mlvl_feats', 'bev_queries', 'prev_bev', 'bev_pos'))
     def get_bev_features(
             self,
             mlvl_feats,
@@ -197,6 +199,7 @@ def get_bev_features(
 
         return bev_embed
 
+    @auto_fp16(apply_to=('mlvl_feats', 'bev_queries', 'object_query_embed', 'prev_bev', 'bev_pos'))
     def forward(self,
                 mlvl_feats,
                 bev_queries,
diff --git a/projects/mmdet3d_plugin/bevformer/runner/__init__.py b/projects/mmdet3d_plugin/bevformer/runner/__init__.py
new file mode 100644
index 0000000..03f906c
--- /dev/null
+++ b/projects/mmdet3d_plugin/bevformer/runner/__init__.py
@@ -0,0 +1 @@
+from .epoch_based_runner import EpochBasedRunner_video
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/bevformer/runner/epoch_based_runner.py b/projects/mmdet3d_plugin/bevformer/runner/epoch_based_runner.py
new file mode 100644
index 0000000..e73e5e7
--- /dev/null
+++ b/projects/mmdet3d_plugin/bevformer/runner/epoch_based_runner.py
@@ -0,0 +1,96 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# ---------------------------------------------
+# Modified by Zhiqi Li
+# ---------------------------------------------
+
+import os.path as osp
+import torch
+import mmcv
+from mmcv.runner.base_runner import BaseRunner
+from mmcv.runner.epoch_based_runner import EpochBasedRunner
+from mmcv.runner.builder import RUNNERS
+from mmcv.runner.checkpoint import save_checkpoint
+from mmcv.runner.utils import get_host_info
+from pprint import pprint
+from mmcv.parallel.data_container import DataContainer
+
+
+@RUNNERS.register_module()
+class EpochBasedRunner_video(EpochBasedRunner):
+
+    '''
+    # basic logic
+
+    input_sequence = [a, b, c]  # given a sequence of samples
+
+    prev_bev = None
+    for each in input_sequence[:-1]:
+        prev_bev = eval_model(each, prev_bev)  # inference only.
+
+    model(input_sequence[-1], prev_bev)  # train the last sample.
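+
+    # Only the last sample receives gradients: the eval_model calls above run
+    # under torch.no_grad(), and eval_model's weights are kept in sync with
+    # the trained model by the TransferWeight hook.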
+ ''' + + def __init__(self, + model, + eval_model=None, + batch_processor=None, + optimizer=None, + work_dir=None, + logger=None, + meta=None, + keys=['gt_bboxes_3d', 'gt_labels_3d', 'img'], + max_iters=None, + max_epochs=None): + super().__init__(model, + batch_processor, + optimizer, + work_dir, + logger, + meta, + max_iters, + max_epochs) + keys.append('img_metas') + self.keys = keys + self.eval_model = eval_model + self.eval_model.eval() + + def run_iter(self, data_batch, train_mode, **kwargs): + if self.batch_processor is not None: + assert False + # outputs = self.batch_processor( + # self.model, data_batch, train_mode=train_mode, **kwargs) + elif train_mode: + + num_samples = data_batch['img'].data[0].size(1) + data_list = [] + prev_bev = None + for i in range(num_samples): + data = {} + for key in self.keys: + if key not in ['img_metas', 'img', 'points']: + data[key] = data_batch[key] + else: + if key == 'img': + data['img'] = DataContainer(data=[data_batch['img'].data[0][:, i]], cpu_only=data_batch['img'].cpu_only, stack=True) + elif key == 'img_metas': + data['img_metas'] = DataContainer(data=[[each[i] for each in data_batch['img_metas'].data[0]]], cpu_only=data_batch['img_metas'].cpu_only) + else: + assert False + data_list.append(data) + with torch.no_grad(): + for i in range(num_samples-1): + if i>0: data_list[i]['prev_bev'] = DataContainer(data=[prev_bev], cpu_only=False) + prev_bev = self.eval_model.val_step(data_list[i], self.optimizer, **kwargs) + + data_list[-1]['prev_bev'] = DataContainer(data=[prev_bev], cpu_only=False) + outputs = self.model.train_step(data_list[-1], self.optimizer, **kwargs) + else: + assert False + # outputs = self.model.val_step(data_batch, self.optimizer, **kwargs) + + if not isinstance(outputs, dict): + raise TypeError('"batch_processor()" or "model.train_step()"' + 'and "model.val_step()" must return a dict') + if 'log_vars' in outputs: + self.log_buffer.update(outputs['log_vars'], outputs['num_samples']) + self.outputs = outputs \ No newline at end of file diff --git a/tools/fp16/dist_train.sh b/tools/fp16/dist_train.sh new file mode 100755 index 0000000..4ac9a15 --- /dev/null +++ b/tools/fp16/dist_train.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +CONFIG=$1 +GPUS=$2 +PORT=${PORT:-28508} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ + $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3} --deterministic diff --git a/tools/fp16/train.py b/tools/fp16/train.py new file mode 100644 index 0000000..eddc349 --- /dev/null +++ b/tools/fp16/train.py @@ -0,0 +1,271 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from __future__ import division
+
+import argparse
+import copy
+import mmcv
+import os
+import time
+import torch
+import warnings
+from mmcv import Config, DictAction
+from mmcv.runner import get_dist_info, init_dist, wrap_fp16_model
+from os import path as osp
+
+from mmdet import __version__ as mmdet_version
+from mmdet3d import __version__ as mmdet3d_version
+#from mmdet3d.apis import train_model
+
+from mmdet3d.datasets import build_dataset
+from mmdet3d.models import build_model
+from mmdet3d.utils import collect_env, get_root_logger
+from mmdet.apis import set_random_seed
+from mmseg import __version__ as mmseg_version
+
+from mmcv.utils import TORCH_VERSION, digit_version
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Train a detector')
+    parser.add_argument('config', help='train config file path')
+    parser.add_argument('--work-dir', help='the dir to save logs and models')
+    parser.add_argument(
+        '--resume-from', help='the checkpoint file to resume from')
+    parser.add_argument(
+        '--no-validate',
+        action='store_true',
+        help='whether not to evaluate the checkpoint during training')
+    group_gpus = parser.add_mutually_exclusive_group()
+    group_gpus.add_argument(
+        '--gpus',
+        type=int,
+        help='number of gpus to use '
+        '(only applicable to non-distributed training)')
+    group_gpus.add_argument(
+        '--gpu-ids',
+        type=int,
+        nargs='+',
+        help='ids of gpus to use '
+        '(only applicable to non-distributed training)')
+    parser.add_argument('--seed', type=int, default=0, help='random seed')
+    parser.add_argument(
+        '--deterministic',
+        action='store_true',
+        help='whether to set deterministic options for CUDNN backend.')
+    parser.add_argument(
+        '--options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file (deprecated), '
+        'change to --cfg-options instead.')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file. If the value to '
+        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+        'Note that the quotation marks are necessary and that no white space '
+        'is allowed.')
+    parser.add_argument(
+        '--launcher',
+        choices=['none', 'pytorch', 'slurm', 'mpi'],
+        default='none',
+        help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
+    parser.add_argument(
+        '--autoscale-lr',
+        action='store_true',
+        help='automatically scale lr with the number of gpus')
+    args = parser.parse_args()
+    if 'LOCAL_RANK' not in os.environ:
+        os.environ['LOCAL_RANK'] = str(args.local_rank)
+
+    if args.options and args.cfg_options:
+        raise ValueError(
+            '--options and --cfg-options cannot be both specified, '
+            '--options is deprecated in favor of --cfg-options')
+    if args.options:
+        warnings.warn('--options is deprecated in favor of --cfg-options')
+        args.cfg_options = args.options
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    cfg = Config.fromfile(args.config)
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
+    # import modules from string list.
+    if cfg.get('custom_imports', None):
+        from mmcv.utils import import_modules_from_strings
+        import_modules_from_strings(**cfg['custom_imports'])
+
+    # import modules from plugin/xx; the registry will be updated
+    if hasattr(cfg, 'plugin'):
+        if cfg.plugin:
+            import importlib
+            if hasattr(cfg, 'plugin_dir'):
+                plugin_dir = cfg.plugin_dir
+                _module_dir = os.path.dirname(plugin_dir)
+                _module_dir = _module_dir.split('/')
+                _module_path = _module_dir[0]
+
+                for m in _module_dir[1:]:
+                    _module_path = _module_path + '.' + m
+                print(_module_path)
+                plg_lib = importlib.import_module(_module_path)
+            else:
+                # import dir is the dirpath for the config file
+                _module_dir = os.path.dirname(args.config)
+                _module_dir = _module_dir.split('/')
+                _module_path = _module_dir[0]
+                for m in _module_dir[1:]:
+                    _module_path = _module_path + '.' + m
+                print(_module_path)
+                plg_lib = importlib.import_module(_module_path)
+
+            from projects.mmdet3d_plugin.bevformer.apis import custom_train_model
+    # set cudnn_benchmark
+    if cfg.get('cudnn_benchmark', False):
+        torch.backends.cudnn.benchmark = True
+
+    # work_dir is determined in this priority: CLI > segment in file > filename
+    if args.work_dir is not None:
+        # update configs according to CLI args if args.work_dir is not None
+        cfg.work_dir = args.work_dir
+    elif cfg.get('work_dir', None) is None:
+        # use config filename as default work_dir if cfg.work_dir is None
+        cfg.work_dir = osp.join('./work_dirs',
+                                osp.splitext(osp.basename(args.config))[0])
+
+    if args.resume_from is not None and osp.isfile(args.resume_from):
+        cfg.resume_from = args.resume_from
+
+    if args.gpu_ids is not None:
+        cfg.gpu_ids = args.gpu_ids
+    else:
+        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
+    if digit_version(TORCH_VERSION) != digit_version('1.8.1'):
+        cfg.optimizer['type'] = 'AdamW'
+    if args.autoscale_lr:
+        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
+        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8
+
+    # init distributed env first, since logger depends on the dist info.
+    if args.launcher == 'none':
+        assert False, 'Non-distributed training is not supported!'
+        distributed = False
+    else:
+        distributed = True
+        init_dist(args.launcher, **cfg.dist_params)
+        # re-set gpu_ids with distributed training mode
+        _, world_size = get_dist_info()
+        cfg.gpu_ids = range(world_size)
+
+    # create work_dir
+    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
+    # dump config
+    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
+    # init the logger before other steps
+    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
+    # specify logger name; if we still use 'mmdet', the output info will be
+    # filtered and won't be saved in the log_file
+    # TODO: ugly workaround to judge whether we are training det or seg model
+    if cfg.model.type in ['EncoderDecoder3D']:
+        logger_name = 'mmseg'
+    else:
+        logger_name = 'mmdet'
+    logger = get_root_logger(
+        log_file=log_file, log_level=cfg.log_level, name=logger_name)
+
+    # init the meta dict to record some important information such as
+    # environment info and seed, which will be logged
+    meta = dict()
+    # log env info
+    env_info_dict = collect_env()
+    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
+    dash_line = '-' * 60 + '\n'
+    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
+                dash_line)
+    meta['env_info'] = env_info
+    meta['config'] = cfg.pretty_text
+
+    # log some basic info
+    logger.info(f'Distributed training: {distributed}')
+    logger.info(f'Config:\n{cfg.pretty_text}')
+
+    # set random seeds
+    if args.seed is not None:
+        logger.info(f'Set random seed to {args.seed}, '
+                    f'deterministic: {args.deterministic}')
+        set_random_seed(args.seed, deterministic=args.deterministic)
+    cfg.seed = args.seed
+    meta['seed'] = args.seed
+    meta['exp_name'] = osp.basename(args.config)
+
+    model = build_model(
+        cfg.model,
+        train_cfg=cfg.get('train_cfg'),
+        test_cfg=cfg.get('test_cfg'))
+    model.init_weights()
+
+    eval_model_config = copy.deepcopy(cfg.model)
+    eval_model = build_model(
+        eval_model_config,
+        train_cfg=cfg.get('train_cfg'),
+        test_cfg=cfg.get('test_cfg'))
+
+    fp16_cfg = cfg.get('fp16', None)
+    if fp16_cfg is not None:
+        wrap_fp16_model(eval_model)
+
+    #eval_model.init_weights()
+    eval_model.load_state_dict(model.state_dict())
+
+    logger.info(f'Model:\n{model}')
+    from projects.mmdet3d_plugin.datasets import custom_build_dataset
+    datasets = [custom_build_dataset(cfg.data.train)]
+    if len(cfg.workflow) == 2:
+        val_dataset = copy.deepcopy(cfg.data.val)
+        # in case we use a dataset wrapper
+        if 'dataset' in cfg.data.train:
+            val_dataset.pipeline = cfg.data.train.dataset.pipeline
+        else:
+            val_dataset.pipeline = cfg.data.train.pipeline
+        # set test_mode=False here in deep copied config
+        # which does not affect AP/AR calculation later
+        # refer to https://mmdetection3d.readthedocs.io/en/latest/tutorials/customize_runtime.html#customize-workflow  # noqa
+        val_dataset.test_mode = False
+        datasets.append(custom_build_dataset(val_dataset))
+    if cfg.checkpoint_config is not None:
+        # save mmdet version, config file content and class names in
+        # checkpoints as meta data
+        cfg.checkpoint_config.meta = dict(
+            mmdet_version=mmdet_version,
+            mmseg_version=mmseg_version,
+            mmdet3d_version=mmdet3d_version,
+            config=cfg.pretty_text,
+            CLASSES=datasets[0].CLASSES,
+            PALETTE=datasets[0].PALETTE  # for segmentors
+            if hasattr(datasets[0], 'PALETTE') else None)
+    # add an attribute for visualization convenience
+    model.CLASSES = datasets[0].CLASSES
+    custom_train_model(
+        model,
+        datasets,
+        cfg,
+        eval_model=eval_model,
+        distributed=distributed,
+        validate=(not args.no_validate),
+        timestamp=timestamp,
+        meta=meta)
+
+
+if __name__ == '__main__':
+    main()