Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix dtype #15

Open
wants to merge 1 commit into
base: support_codetr
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions mmdet/models/dense_heads/detr_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import SampleList
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh, bbox_overlaps
from mmdet.structures.bbox import (bbox_cxcywh_to_xyxy, bbox_overlaps,
bbox_xyxy_to_cxcywh)
from mmdet.utils import (ConfigType, InstanceList, OptInstanceList,
OptMultiConfig, reduce_mean)
from ..utils import multi_apply
from ..losses import QualityFocalLoss
from ..utils import multi_apply


@MODELS.register_module()
class DETRHead(BaseModule):
Expand Down Expand Up @@ -424,7 +426,8 @@ def _get_targets_single(self, cls_score: Tensor, bbox_pred: Tensor,
gt_instances=gt_instances,
img_meta=img_meta)

gt_bboxes = gt_instances.bboxes
# The type of `bboxes` should be consistent with the `cls_score`
gt_bboxes = gt_instances.bboxes.type_as(cls_score)
gt_labels = gt_instances.labels
pos_inds = torch.nonzero(
assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
Expand All @@ -448,8 +451,14 @@ def _get_targets_single(self, cls_score: Tensor, bbox_pred: Tensor,
# DETR regress the relative position of boxes (cxcywh) in the image.
# Thus the learning target should be normalized by the image size, also
# the box format should be converted from the default x1y1x2y2 to cxcywh.
pos_gt_bboxes_normalized = pos_gt_bboxes / factor

# `pos_gt_bboxes / factor` will return a float tensor by default.
# Use `type_as` here to make sure the dtype of the normalized gt
# bboxes matches that of `bbox_targets`.
pos_gt_bboxes_normalized = (pos_gt_bboxes / factor).type_as(
bbox_targets)
pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized)
pos_gt_bboxes_targets = pos_gt_bboxes_targets
bbox_targets[pos_inds] = pos_gt_bboxes_targets
return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
neg_inds)
Expand Down
12 changes: 9 additions & 3 deletions mmdet/models/layers/transformer/dino_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,12 +493,18 @@ def collate_dn_queries(self, input_label_query: Tensor,
mapper = (batch_idx_expand, map_query_index)

batched_label_query = torch.zeros(
batch_size, num_denoising_queries, self.embed_dims, device=device)
batch_size, num_denoising_queries, self.embed_dims, device=device,
dtype=input_label_query.dtype)
# `input_label_query` is extracted from `nn.Embedding`, of which dtype
# has been converted into the target dtype.
# However the dtype of `batched_label_query` is always `float32`
batched_bbox_query = torch.zeros(
batch_size, num_denoising_queries, 4, device=device)
batch_size, num_denoising_queries, 4, device=device,
dtype=input_label_query.dtype)

batched_label_query[mapper] = input_label_query
batched_bbox_query[mapper] = input_bbox_query
batched_bbox_query[mapper] = input_bbox_query.to(
dtype=input_label_query.dtype)
return batched_label_query, batched_bbox_query

def generate_dn_mask(self, max_num_target: int, num_groups: int,
Expand Down
4 changes: 2 additions & 2 deletions mmdet/models/losses/gfocal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def quality_focal_loss(pred, target, beta=2.0):
pos_label = label[pos].long()
# positives are supervised by bbox quality (IoU) score
scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
loss[pos, pos_label] = (F.binary_cross_entropy_with_logits(
pred[pos, pos_label], score[pos],
reduction='none') * scale_factor.abs().pow(beta)
reduction='none') * scale_factor.abs().pow(beta)).type_as(loss)

loss = loss.sum(dim=1, keepdim=False)
return loss
Expand Down
46 changes: 26 additions & 20 deletions projects/CO-DETR/codetr/co_dino_head.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from typing import Dict, List, Tuple
from torch import Tensor
from mmcv.cnn import Linear
from mmcv.ops import batched_nms, interpolate
from mmengine.structures import InstanceData
from mmcv.ops import batched_nms
from mmdet.utils import InstanceList, reduce_mean
from mmdet.structures import SampleList
from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures.bbox import (bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh, bbox_overlaps)
from mmdet.models.utils import multi_apply, unpack_gt_instances
from torch import Tensor

from mmdet.models import DINOHead
from mmdet.models.layers import CdnQueryGenerator
from mmdet.models.layers.transformer import inverse_sigmoid
from mmcv.ops import batched_nms
from mmdet.models.task_modules.samplers import PseudoSampler
from mmcv.cnn import Linear

from mmdet.models.utils import multi_apply, unpack_gt_instances
from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import SampleList
from mmdet.structures.bbox import (bbox_cxcywh_to_xyxy, bbox_overlaps,
bbox_xyxy_to_cxcywh)
from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType,
OptInstanceList, reduce_mean)
from mmdet.models.layers import CdnQueryGenerator
from mmdet.models import DINOHead


@MODELS.register_module()
Expand Down Expand Up @@ -126,10 +126,13 @@ def forward(self,
mlvl_positional_encodings = []
for feat in mlvl_feats:
mlvl_masks.append(
F.interpolate(img_masks[None],
size=feat.shape[-2:]).to(torch.bool).squeeze(0))
interpolate(img_masks[None],
size=feat.shape[-2:]).to(torch.bool).squeeze(0))
# `positional_encoding` will return a float tensor by
# default. Convert it to the same dtype as `feat` for pure
# bf16/fp16 training
mlvl_positional_encodings.append(
self.positional_encoding(mlvl_masks[-1]))
self.positional_encoding(mlvl_masks[-1]).to(dtype=feat.dtype))

query_embeds = None
hs, inter_references, topk_score, topk_anchor, enc_outputs = \
Expand Down Expand Up @@ -397,10 +400,13 @@ def forward_aux(self, mlvl_feats, img_metas, aux_targets, head_idx):
mlvl_positional_encodings = []
for feat in mlvl_feats:
mlvl_masks.append(
F.interpolate(img_masks[None],
size=feat.shape[-2:]).to(torch.bool).squeeze(0))
interpolate(img_masks[None],
size=feat.shape[-2:]).to(torch.bool).squeeze(0))
# `positional_encoding` will return a float tensor by
# default. Convert it to the same dtype as `feat` for pure
# bf16/fp16 training
mlvl_positional_encodings.append(
self.positional_encoding(mlvl_masks[-1]))
self.positional_encoding(mlvl_masks[-1]).to(dtype=feat.dtype))

query_embeds = None
hs, inter_references = self.transformer.forward_aux(
Expand Down
23 changes: 16 additions & 7 deletions projects/CO-DETR/codetr/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmengine.model.weight_init import xavier_init
from mmdet.registry import MODELS
from mmcv.cnn.bricks.transformer import (BaseTransformerLayer,
TransformerLayerSequence,
build_transformer_layer_sequence)
from mmdet.models.layers.transformer import inverse_sigmoid
from mmcv.ops import MultiScaleDeformableAttention
from mmengine.model import BaseModule
from mmengine.model.weight_init import xavier_init
from torch.nn.init import normal_

from mmdet.models.layers.transformer import inverse_sigmoid
from mmdet.registry import MODELS

try:
from fairscale.nn.checkpoint import checkpoint_wrapper
Expand Down Expand Up @@ -308,6 +308,7 @@ def gen_encoder_output_proposals(self, memory, memory_padding_mask,
output_memory = output_memory.masked_fill(~output_proposals_valid,
float(0))
output_memory = self.enc_output_norm(self.enc_output(output_memory))
output_proposals = output_proposals.type_as(output_memory)
return output_memory, output_proposals

@staticmethod
Expand Down Expand Up @@ -1034,8 +1035,11 @@ def forward(self,
reference_points_input = \
reference_points[:, :, None] * valid_ratios[:, None]

# `query_sine_embed` will be float by default. Just convert it to
# the same type as `query` to avoid type mismatch when using pure
# bf16/fp16 training
query_sine_embed = self.gen_sineembed_for_position(
reference_points_input[:, :, 0, :], self.embed_dims//2)
reference_points_input[:, :, 0, :], self.embed_dims//2).type_as(query)
query_pos = self.ref_point_head(query_sine_embed)

query_pos = query_pos.permute(1, 0, 2)
Expand Down Expand Up @@ -1262,12 +1266,16 @@ def forward_aux(self,
topk_coords_unact = inverse_sigmoid((pos_anchors))
reference_points = (pos_anchors)
init_reference_out = reference_points

# `get_proposal_pos_embed` will return a float tensor by default.
# Convert it to the same dtype as `mlvl_feats` to avoid a dtype
# mismatch during pure fp16/bf16 training.
if self.num_co_heads > 0:
pos_trans_out = self.aux_pos_trans_norm[head_idx](
self.aux_pos_trans[head_idx](self.get_proposal_pos_embed(topk_coords_unact)))
self.aux_pos_trans[head_idx](self.get_proposal_pos_embed(topk_coords_unact).type_as(mlvl_feats[0])))
query = pos_trans_out
if self.with_coord_feat:
query = query + self.pos_feats_norm[head_idx](self.pos_feats_trans[head_idx](pos_feats))
query = query + self.pos_feats_norm[head_idx](self.pos_feats_trans[head_idx](pos_feats).type_as(mlvl_feats[0]))

# decoder
query = query.permute(1, 0, 2)
Expand All @@ -1292,6 +1300,7 @@ def forward_aux(self,

from mmcv.cnn import build_norm_layer


@MODELS.register_module()
class DetrTransformerEncoder(TransformerLayerSequence):
"""TransformerEncoder of DETR.
Expand Down