-
Notifications
You must be signed in to change notification settings - Fork 41
/
Copy pathmasking_consistency_module.py
151 lines (131 loc) · 5.62 KB
/
masking_consistency_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# ---------------------------------------------------------------
# Copyright (c) 2022 ETH Zurich, Lukas Hoyer. All rights reserved.
# Licensed under the Apache License, Version 2.0
# ---------------------------------------------------------------
import random
import torch
from torch.nn import Module
from mmseg.models.uda.teacher_module import EMATeacher
from mmseg.models.utils.dacs_transforms import get_mean_std, strong_transform
from mmseg.models.utils.masking_transforms import build_mask_generator
class MaskingConsistencyModule(Module):
    """Masked Image Consistency (MIC) training module.

    Builds a masked-image training batch from source and/or target images
    (selected by ``mask_mode``), obtains pseudo labels for target images
    either from the host UDA method or from an internal EMA teacher, and
    returns the segmentation loss of the model on the masked images.
    """

    def __init__(self, require_teacher, cfg):
        """Configure masking consistency from the UDA config.

        Args:
            require_teacher (bool): Force a separate EMA teacher even if
                the masking branch shares ('same') alpha and pseudo-label
                threshold with the host UDA method.
            cfg (dict): UDA config providing ``max_iters``,
                ``color_jitter_strength``, ``color_jitter_probability``,
                ``mask_mode``, ``mask_alpha``, ``mask_pseudo_threshold``,
                ``mask_lambda``, and ``mask_generator``; optionally
                ``source_only``.
        """
        super().__init__()

        self.source_only = cfg.get('source_only', False)
        self.max_iters = cfg['max_iters']
        self.color_jitter_s = cfg['color_jitter_strength']
        self.color_jitter_p = cfg['color_jitter_probability']

        self.mask_mode = cfg['mask_mode']
        self.mask_alpha = cfg['mask_alpha']
        self.mask_pseudo_threshold = cfg['mask_pseudo_threshold']
        self.mask_lambda = cfg['mask_lambda']
        self.mask_gen = build_mask_generator(cfg['mask_generator'])

        assert self.mask_mode in [
            'separate', 'separatesrc', 'separatetrg', 'separateaug',
            'separatesrcaug', 'separatetrgaug'
        ]

        # A separate EMA teacher is only required when the masking branch
        # uses its own EMA alpha or pseudo-label threshold; otherwise the
        # pseudo labels are shared with the host UDA method (see __call__).
        self.teacher = None
        if require_teacher or \
                self.mask_alpha != 'same' or \
                self.mask_pseudo_threshold != 'same':
            self.teacher = EMATeacher(use_mask_params=True, cfg=cfg)

        self.debug = False
        self.debug_output = {}

    def update_weights(self, model, iter):
        """Propagate an EMA weight update to the internal teacher, if any."""
        # NOTE: `iter` shadows the builtin; the name is kept because it is
        # part of the caller-visible signature.
        if self.teacher is not None:
            self.teacher.update_weights(model, iter)

    def update_debug_state(self):
        """Forward the current debug flag to the internal teacher, if any."""
        if self.teacher is not None:
            self.teacher.debug = self.debug

    def __call__(self,
                 model,
                 img,
                 img_metas,
                 gt_semantic_seg,
                 target_img,
                 target_img_metas,
                 valid_pseudo_mask,
                 pseudo_label=None,
                 pseudo_weight=None):
        """Compute the masked-image consistency loss.

        Args:
            model: Segmentor exposing ``forward_train`` and
                ``debug_output``.
            img: Source image batch.  # assumes NCHW — TODO confirm
            img_metas: Meta info used for normalization statistics.
            gt_semantic_seg: Source ground-truth label maps.
            target_img: Target-domain image batch.
            target_img_metas: Meta info for the target images.
            valid_pseudo_mask: Validity mask passed to the EMA teacher.
            pseudo_label: Pseudo labels shared by the host UDA method;
                required when no internal teacher is used.
            pseudo_weight: Matching pseudo-label weights; required when
                no internal teacher is used.

        Returns:
            The loss dict from ``model.forward_train``, with
            ``decode.loss_seg`` scaled by ``mask_lambda`` if != 1.
        """
        self.update_debug_state()
        self.debug_output = {}
        model.debug_output = {}
        dev = img.device
        means, stds = get_mean_std(img_metas, dev)

        if not self.source_only:
            # Share the pseudo labels with the host UDA method
            if self.teacher is None:
                assert self.mask_alpha == 'same'
                assert self.mask_pseudo_threshold == 'same'
                assert pseudo_label is not None
                assert pseudo_weight is not None
                masked_plabel = pseudo_label
                masked_pweight = pseudo_weight
            # Use a separate EMA teacher for MIC
            else:
                masked_plabel, masked_pweight = \
                    self.teacher(
                        target_img, target_img_metas, valid_pseudo_mask)
                if self.debug:
                    self.debug_output['Mask Teacher'] = {
                        'Img': target_img.detach(),
                        'Pseudo Label': masked_plabel.cpu().numpy(),
                        'Pseudo Weight': masked_pweight.cpu().numpy(),
                    }
        # Don't use target images at all
        if self.source_only:
            masked_img = img
            masked_lbl = gt_semantic_seg
            masked_seg_weight = None
        # Use 1x source image and 1x target image for MIC
        elif self.mask_mode in ['separate', 'separateaug']:
            assert img.shape[0] == 2
            masked_img = torch.stack([img[0], target_img[0]])
            masked_lbl = torch.stack(
                [gt_semantic_seg[0], masked_plabel[0].unsqueeze(0)])
            # Ground-truth pixels are fully trusted (weight 1); pseudo-label
            # pixels keep their confidence-based weight.
            gt_pixel_weight = torch.ones(masked_pweight[0].shape, device=dev)
            masked_seg_weight = torch.stack(
                [gt_pixel_weight, masked_pweight[0]])
        # Use only source images for MIC
        elif self.mask_mode in ['separatesrc', 'separatesrcaug']:
            masked_img = img
            masked_lbl = gt_semantic_seg
            masked_seg_weight = None
        # Use only target images for MIC
        elif self.mask_mode in ['separatetrg', 'separatetrgaug']:
            masked_img = target_img
            masked_lbl = masked_plabel.unsqueeze(1)
            masked_seg_weight = masked_pweight
        else:
            raise NotImplementedError(self.mask_mode)

        # Apply color augmentation ('...aug' modes only)
        if 'aug' in self.mask_mode:
            strong_parameters = {
                'mix': None,
                'color_jitter': random.uniform(0, 1),
                'color_jitter_s': self.color_jitter_s,
                'color_jitter_p': self.color_jitter_p,
                'blur': random.uniform(0, 1),
                'mean': means[0].unsqueeze(0),
                'std': stds[0].unsqueeze(0)
            }
            masked_img, _ = strong_transform(
                strong_parameters, data=masked_img.clone())

        # Apply masking to image
        masked_img = self.mask_gen.mask_image(masked_img)

        # Train on masked images
        masked_loss = model.forward_train(
            masked_img,
            img_metas,
            masked_lbl,
            seg_weight=masked_seg_weight,
        )
        if self.mask_lambda != 1:
            # Only the decode-head loss is scaled, not auxiliary losses.
            masked_loss['decode.loss_seg'] *= self.mask_lambda

        if self.debug:
            self.debug_output['Masked'] = model.debug_output
            if masked_seg_weight is not None:
                self.debug_output['Masked']['PL Weight'] = \
                    masked_seg_weight.cpu().numpy()

        return masked_loss