-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathmodel_VGG.py
104 lines (92 loc) · 4.17 KB
/
model_VGG.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
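'''advancedEAST model in PyTorch: a VGG16 (batch-norm) backbone with an FPN-style top-down merging path.
See the paper "Feature Pyramid Networks for Object Detection" for details of the top-down merging idea.
'''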
import torch
import torch.nn as nn
import torch.nn.functional as F
import cfg
# Run on the CUDA device when PyTorch can see one, otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
class advancedEAST(nn.Module):
    """VGG16-BN backbone with a top-down merging path and a 7-channel output head.

    Bottom-up: four VGG stages (``layer2`` .. ``layer5``), each ending in a
    2x2 max-pool. Top-down: the deeper feature map is upsampled (nearest) to
    the spatial size of the next shallower one, concatenated with it along
    the channel dimension and fused by a ``merging`` stage.

    ``forward`` returns the channel-wise concatenation
    ``[inside_score (1), side_v_code (2), side_v_coord (4)]`` (7 channels) at
    the spatial size of ``layer2``'s output, i.e. the input size
    floor-divided by 2 twice.

    Args:
        locked_layers: keyword-only. If truthy, freeze the first two
            ``nn.Conv2d`` layers of ``layer2`` (``requires_grad=False``).
            ``None`` (default) reads the flag from the project ``cfg``
            module's ``locked_layers`` attribute.
    """

    def __init__(self, *, locked_layers=None):
        super().__init__()
        # Bottom-up layers (output channels, stride w.r.t. the input).
        self.layer2 = self.make_layers([64, 64, 'M', 128, 128, 'M'], in_channels=3)  # 128, /4
        self.layer3 = self.make_layers([256, 256, 256, 'M'], in_channels=128)  # 256, /8
        self.layer4 = self.make_layers([512, 512, 512, 'M'], in_channels=256)  # 512, /16
        self.layer5 = self.make_layers([512, 512, 512, 'M'], in_channels=512)  # 512, /32
        # Top-down merging stages: 1024 -> 128, 384 -> 64, 192 -> 32 channels.
        self.merging1 = self.merging(i=2)
        self.merging2 = self.merging(i=3)
        self.merging3 = self.merging(i=4)
        # Layers before the output heads.
        self.last_bn = nn.BatchNorm2d(32)
        self.conv_last = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.inside_score_conv = nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0)
        self.side_v_code_conv = nn.Conv2d(32, 2, kernel_size=1, stride=1, padding=0)
        self.side_v_coord_conv = nn.Conv2d(32, 4, kernel_size=1, stride=1, padding=0)
        # Optionally lock the first two conv layers of the backbone.
        if locked_layers is None:
            locked_layers = cfg.locked_layers
        if locked_layers:
            self._freeze_first_convs(self.layer2, count=2)

    @staticmethod
    def _freeze_first_convs(block, count=2):
        """Set ``requires_grad=False`` on the first ``count`` Conv2d children of ``block``.

        Only the convolutions are frozen; the BatchNorm layers that follow
        them remain trainable.
        """
        convs = [m for m in block.children() if isinstance(m, nn.Conv2d)]
        for idx, conv in enumerate(convs[:count], start=1):
            # Message: "freezing parameters of layer {idx}, layer: {conv}".
            print('冻结第{}层参数,层属性:{}'.format(idx, conv))
            for param in conv.parameters():
                param.requires_grad = False

    def make_layers(self, cfg_list, in_channels=3, batch_norm=True):  # VGG part
        """Build a VGG stage as an ``nn.Sequential``.

        Args:
            cfg_list: items are either an int (output channels of a 3x3,
                padding-1 conv) or ``'M'`` (a 2x2, stride-2 max-pool).
            in_channels: channels of the stage's input.
            batch_norm: if true, insert ``BatchNorm2d`` between each conv
                and its ReLU.

        Returns:
            The stage as an ``nn.Sequential``.
        """
        layers = []
        for v in cfg_list:
            if v == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
        return nn.Sequential(*layers)

    def merging(self, i=2):
        """Build top-down merging stage ``i`` (one of 2, 3, 4).

        The input width is fixed per stage (1024, 384, 192), which is the
        channel count of the concatenation that ``forward`` feeds it. The
        output width is ``128 // 2 ** (i - 2)``, i.e. 128, 64 and 32.
        Structure: BN -> 1x1 conv -> ReLU -> BN -> 3x3 conv -> ReLU.
        """
        in_size = {2: 1024, 3: 384, 4: 192}
        in_channels = in_size[i]
        out_channels = 128 // 2 ** (i - 2)
        return nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

    @staticmethod
    def _upsample_to(x, ref):
        """Nearest-neighbour upsample ``x`` to the spatial (last two dims) size of ``ref``."""
        return F.interpolate(x, size=ref.shape[-2:], mode='nearest')

    def forward(self, x):
        """Run the network; returns a 7-channel map at 1/4 of the input resolution."""
        # Bottom-up
        f4 = self.layer2(x)   # 128 channels
        f3 = self.layer3(f4)  # 256 channels
        f2 = self.layer4(f3)  # 512 channels
        f1 = self.layer5(f2)  # 512 channels
        # Top-down. Upsample to the skip tensor's exact size rather than a
        # fixed x2, so input sides that are not multiples of 32 (where the
        # floor in max-pooling makes sizes disagree) still concatenate. For
        # multiples of 32 the result is identical to x2 nearest upsampling.
        h2 = self.merging1(torch.cat((self._upsample_to(f1, f2), f2), dim=1))  # 1024 -> 128
        h3 = self.merging2(torch.cat((self._upsample_to(h2, f3), f3), dim=1))  # 128+256 -> 64
        h4 = self.merging3(torch.cat((self._upsample_to(h3, f4), f4), dim=1))  # 64+128 -> 32
        # Layers before the output heads.
        before_output = F.relu(self.conv_last(self.last_bn(h4)))
        inside_score = self.inside_score_conv(before_output)
        side_v_code = self.side_v_code_conv(before_output)
        side_v_coord = self.side_v_coord_conv(before_output)
        return torch.cat((inside_score, side_v_code, side_v_coord), dim=1)
if __name__ == '__main__':
    # Build the model and, if cfg.model_summary is set, print a layer summary
    # for a 3x128x128 input using the optional torchsummary package.
    net = advancedEAST().to(device)
    if cfg.model_summary:
        # Only the import is guarded, so an ImportError raised inside
        # summary() itself is not misreported as a missing package.
        try:
            from torchsummary import summary
        except ImportError:
            print("\"torchsummary\" not found, please install to visualize the model architecture.")
            cfg.model_summary = False
        else:
            summary(net, input_size=(3, 128, 128))