Skip to content

Commit

Permalink
Update DLA
Browse files Browse the repository at this point in the history
  • Loading branch information
kuangliu committed Nov 23, 2020
1 parent 54b2adc commit 5e3f990
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions models/dla.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,22 @@ class Tree(nn.Module):
def __init__(self, block, in_channels, out_channels, level=1, stride=1):
super(Tree, self).__init__()
self.level = level
self.root = Root((level+1)*out_channels, out_channels)
if level == 1:
self.root = Root(2*out_channels, out_channels)
self.left_node = block(in_channels, out_channels, stride=stride)
self.right_node = block(out_channels, out_channels, stride=1)
else:
self.root = Root((level+2)*out_channels, out_channels)
for i in reversed(range(1, level)):
subtree = Tree(block, in_channels, out_channels,
level=i, stride=stride)
self.__setattr__('level_%d' % i, subtree)
self.prev_root = block(in_channels, out_channels, stride=stride)
self.left_node = block(out_channels, out_channels, stride=1)
self.right_node = block(out_channels, out_channels, stride=1)

def forward(self, x):
xs = []
xs = [self.prev_root(x)] if self.level > 1 else []
for i in reversed(range(1, self.level)):
level_i = self.__getattr__('level_%d' % i)
x = level_i(x)
Expand Down

0 comments on commit 5e3f990

Please sign in to comment.