diff --git a/backend/src/nodes/impl/pytorch/rife/IFNet_HDv3_v4_14_align.py b/backend/src/nodes/impl/pytorch/rife/IFNet_HDv3_v4_14_align.py
index e72f2b7f0..a1bfd568b 100644
--- a/backend/src/nodes/impl/pytorch/rife/IFNet_HDv3_v4_14_align.py
+++ b/backend/src/nodes/impl/pytorch/rife/IFNet_HDv3_v4_14_align.py
@@ -144,99 +144,119 @@ def align_images(
         blur_strength,  # noqa: ANN001
         ensemble,  # noqa: ANN001
         device,  # noqa: ANN001
+        flow2=None,  # noqa: ANN001
+        img1_blurred=None,  # noqa: ANN001
     ):
-        # optional blur
-        if blur_strength is not None and blur_strength > 0:
-            blur = transforms.GaussianBlur(
-                kernel_size=(5, 5), sigma=(blur_strength, blur_strength)
-            )
-            img0_blurred = blur(img0)
-            img1_blurred = blur(img1)
-        else:
-            img0_blurred = img0
-            img1_blurred = img1
-
-        f0 = self.encode(img0_blurred[:, :3])
-        f1 = self.encode(img1_blurred[:, :3])
-        flow_list = []
-        mask_list = []
-        flow = None
-        mask = None
-        block = [self.block0, self.block1, self.block2, self.block3]
-        for i in range(4):
-            if flow is None:
-                flow, mask = block[i](
-                    torch.cat(
-                        (img0_blurred[:, :3], img1_blurred[:, :3], f0, f1, timestep), 1
-                    ),
-                    None,
-                    scale=scale_list[i],
-                )
-                if ensemble:
-                    f_, m_ = block[i](
+        def compute_flow(
+            img0_blurred: torch.Tensor, img1_blurred: torch.Tensor, timestep: torch.Tensor
+        ) -> torch.Tensor:
+            f0 = self.encode(img0_blurred[:, :3])
+            f1 = self.encode(img1_blurred[:, :3])
+            flow = None
+            mask = None
+            block = [self.block0, self.block1, self.block2, self.block3]
+            for i in range(4):
+                if flow is None:
+                    flow, mask = block[i](
                         torch.cat(
                             (
-                                img1_blurred[:, :3],
                                 img0_blurred[:, :3],
-                                f1,
+                                img1_blurred[:, :3],
                                 f0,
-                                1 - timestep,
+                                f1,
+                                timestep,
                             ),
                             1,
                         ),
                         None,
                         scale=scale_list[i],
                     )
-                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
-                    mask = (mask + (-m_)) / 2
-            else:
-                wf0 = warp(f0, flow[:, :2], device)
-                wf1 = warp(f1, flow[:, 2:4], device)
-                fd, m0 = block[i](
-                    torch.cat(
-                        (
-                            img0_blurred[:, :3],
-                            img1_blurred[:, :3],
-                            wf0,
-                            wf1,
-                            timestep,
-                            mask,
-                        ),
-                        1,
-                    ),
-                    flow,
-                    scale=scale_list[i],
-                )
-                if ensemble:
-                    f_, m_ = block[i](
+                    if ensemble:
+                        f_, m_ = block[i](
+                            torch.cat(
+                                (
+                                    img1_blurred[:, :3],
+                                    img0_blurred[:, :3],
+                                    f1,
+                                    f0,
+                                    1 - timestep,
+                                ),
+                                1,
+                            ),
+                            None,
+                            scale=scale_list[i],
+                        )
+                        flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
+                        mask = (mask + (-m_)) / 2
+                else:
+                    wf0 = warp(f0, flow[:, :2], device)
+                    wf1 = warp(f1, flow[:, 2:4], device)
+                    fd, m0 = block[i](
                         torch.cat(
                             (
-                                img1_blurred[:, :3],
                                 img0_blurred[:, :3],
-                                wf1,
+                                img1_blurred[:, :3],
                                 wf0,
-                                1 - timestep,
-                                -mask,
+                                wf1,
+                                timestep,
+                                mask,
                             ),
                             1,
                         ),
-                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
+                        flow,
                         scale=scale_list[i],
                     )
-                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
-                    mask = (m0 + (-m_)) / 2
-                else:
-                    mask = m0
-                flow = flow + fd
-            mask_list.append(mask)
-            flow_list.append(flow)
+                    if ensemble:
+                        f_, m_ = block[i](
+                            torch.cat(
+                                (
+                                    img1_blurred[:, :3],
+                                    img0_blurred[:, :3],
+                                    wf1,
+                                    wf0,
+                                    1 - timestep,
+                                    -mask,
+                                ),
+                                1,
+                            ),
+                            torch.cat((flow[:, 2:4], flow[:, :2]), 1),
+                            scale=scale_list[i],
+                        )
+                        fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
+                        mask = (m0 + (-m_)) / 2
+                    else:
+                        mask = m0
+                    flow = flow + fd
+            return flow
+
+        # optional blur
+        if blur_strength is not None and blur_strength > 0:
+            blur = transforms.GaussianBlur(
+                kernel_size=(5, 5), sigma=(blur_strength, blur_strength)
+            )
+            img0_blurred = blur(img0)
+            if img1_blurred is None:
+                img1_blurred = blur(img1)
+        else:
+            img0_blurred = img0
+            img1_blurred = img1
+
+        # align image to reference
+        flow1 = compute_flow(img0_blurred, img1_blurred, timestep)
+
+        # align image to itself
+        if flow2 is None:
+            flow2 = compute_flow(img1_blurred, img1_blurred, timestep)
+
+        # subtract flow2 from flow1 to compensate for shifts
+        compensated_flow = flow1 - flow2
 
-        # apply warp to original image
-        aligned_img0 = warp(img0, flow_list[-1][:, :2], device)
+        # warp image with the compensated flow
+        aligned_img0 = warp(img0, compensated_flow[:, :2], device)
 
         # add clamp here instead of in warplayer script, as it changes the output there
-        aligned_img0 = aligned_img0.clamp(min=0.0, max=1.0)
-        return aligned_img0, flow_list[-1]
+        aligned_img0 = aligned_img0.clamp(0, 1)
+        return aligned_img0, compensated_flow, flow2, img1_blurred
 
     def forward(
         self,
@@ -262,9 +282,19 @@ def forward(
         else:
             timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3])  # type: ignore
 
+        flow2 = None
+        img1_blurred = None
         for _iteration in range(num_iterations):
-            aligned_img0, flow = self.align_images(
-                img0, img1, timestep, scale_list, blur_strength, ensemble, device
+            aligned_img0, flow, flow2, img1_blurred = self.align_images(
+                img0,
+                img1,
+                timestep,
+                scale_list,
+                blur_strength,
+                ensemble,
+                device,
+                flow2,
+                img1_blurred,
             )
             img0 = aligned_img0  # use the aligned image as img0 for the next iteration
 
diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/align_image_to_reference.py b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/align_image_to_reference.py
index 43d18954e..ccb9e81db 100644
--- a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/align_image_to_reference.py
+++ b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/align_image_to_reference.py
@@ -105,12 +105,10 @@ def align_images(
 
     source_h, source_w, _ = get_h_w_c(source_img)
 
-    # resize, then shift reference left because rife shifts slightly to the right
+    # resize image to reference
     target_img_resized = resize(
         target_img, (source_w, source_h), filter=ResizeFilter.LANCZOS
     )
-    target_img_resized = np.roll(target_img_resized, -1, axis=1)
-    target_img_resized[:, -1] = target_img_resized[:, -2]
 
     # padding because rife can only work with multiples of 32 (changes with precision mode)
     pad_h, pad_w = calculate_padding(source_h, source_w, precision_mode)
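
Note for reviewers (not part of the diff): the compensation step works because an img1 -> img1 flow should be identically zero, so whatever compute_flow returns for it is the model's own systematic shift (the rightward drift the old np.roll hack worked around), and subtracting it from the img0 -> img1 flow cancels that shift. A minimal sketch of the arithmetic, with illustrative tensor shapes and an assumed half-pixel bias:

import torch

# toy estimator bias: every predicted flow is offset half a pixel to the right
bias = torch.tensor([0.5, 0.0]).view(1, 2, 1, 1)

true_flow = torch.randn(1, 2, 8, 8)        # ground-truth img0 -> img1 motion
flow1 = true_flow + bias                   # biased estimate, img0 -> img1
flow2 = torch.zeros(1, 2, 8, 8) + bias     # img1 -> img1 "self flow" is pure bias

compensated_flow = flow1 - flow2           # bias cancels, true motion remains
assert torch.allclose(compensated_flow, true_flow)

This is also why forward() can cache flow2 and img1_blurred across iterations: the reference image never changes, so the img1 -> img1 pass and the reference blur only need to run once.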