diff --git a/code/dcn/modules/deform_conv.py b/code/dcn/modules/deform_conv.py index ba77ee6..0e678b3 100644 --- a/code/dcn/modules/deform_conv.py +++ b/code/dcn/modules/deform_conv.py @@ -143,9 +143,10 @@ def forward(self, input, temp): dimension_T = 'T' in self.dimension dimension_H = 'H' in self.dimension dimension_W = 'W' in self.dimension + b, c, t, h, w = temp.shape if self.length == 2: - b, c, t, h, w = temp.shape - offset = temp.clone().resize_(b, 81, t, h, w) + temp1 = temp.clone()[:, 0:81 - c, :, :, :] + offset = torch.cat((temp.clone(), temp1), dim=1) if dimension_T == False: for i in range( self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]): @@ -166,6 +167,8 @@ def forward(self, input, temp): offset[:, i * 3 + 2, :, :, :] = 0 # W if self.length == 1: + temp1 = temp.clone() + offset = torch.cat((temp.clone(), temp1, temp1), dim=1) if dimension_T == True: for i in range( self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]): @@ -229,9 +232,9 @@ def forward(self, input): dimension_H = 'H' in self.dimension dimension_W = 'W' in self.dimension b, c, t, h, w = temp.shape - temp1 = temp.clone()[:,0:81-c,:,:,:] - offset = torch.cat((temp.clone(),temp1),dim=1) if self.length == 2: + temp1 = temp.clone()[:, 0:81 - c, :, :, :] + offset = torch.cat((temp.clone(), temp1), dim=1) if dimension_T == False: for i in range( self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]): @@ -252,6 +255,8 @@ def forward(self, input): offset[:, i * 3 + 2, :, :, :] = 0 # W if self.length == 1: + temp1 = temp.clone() + offset = torch.cat((temp.clone(), temp1, temp1), dim=1) if dimension_T == True: for i in range( self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]):