From f6d01d80af5f2a0f7709d0591b85dada76262213 Mon Sep 17 00:00:00 2001 From: tepete Date: Fri, 21 Jun 2024 10:26:06 +0200 Subject: [PATCH 1/6] remove manual 1 pixel shift to the left --- .../pytorch/processing/align_image_to_reference.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/align_image_to_reference.py b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/align_image_to_reference.py index 43d18954e..89089d3b4 100644 --- a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/align_image_to_reference.py +++ b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/align_image_to_reference.py @@ -105,13 +105,11 @@ def align_images( source_h, source_w, _ = get_h_w_c(source_img) - # resize, then shift reference left because rife shifts slightly to the right + # resize image to reference target_img_resized = resize( target_img, (source_w, source_h), filter=ResizeFilter.LANCZOS ) - target_img_resized = np.roll(target_img_resized, -1, axis=1) - target_img_resized[:, -1] = target_img_resized[:, -2] - + # padding because rife can only work with multiples of 32 (changes with precision mode) pad_h, pad_w = calculate_padding(source_h, source_w, precision_mode) top_pad = pad_h // 2 From 8ea404026b81473b4bcd481206c90b4c5bc3535c Mon Sep 17 00:00:00 2001 From: tepete Date: Fri, 21 Jun 2024 10:55:31 +0200 Subject: [PATCH 2/6] align reference to itself to compensate for shifts --- .../pytorch/rife/IFNet_HDv3_v4_14_align.py | 167 ++++++++++-------- 1 file changed, 90 insertions(+), 77 deletions(-) diff --git a/backend/src/nodes/impl/pytorch/rife/IFNet_HDv3_v4_14_align.py b/backend/src/nodes/impl/pytorch/rife/IFNet_HDv3_v4_14_align.py index e72f2b7f0..9601ea0cd 100644 --- a/backend/src/nodes/impl/pytorch/rife/IFNet_HDv3_v4_14_align.py +++ b/backend/src/nodes/impl/pytorch/rife/IFNet_HDv3_v4_14_align.py @@ -144,99 +144,110 @@ def align_images( blur_strength, # noqa: ANN001 ensemble, # 
noqa: ANN001 device, # noqa: ANN001 + flow2=None, # noqa: ANN001 + img1_blurred=None # noqa: ANN001 ): - # optional blur - if blur_strength is not None and blur_strength > 0: - blur = transforms.GaussianBlur( - kernel_size=(5, 5), sigma=(blur_strength, blur_strength) - ) - img0_blurred = blur(img0) - img1_blurred = blur(img1) - else: - img0_blurred = img0 - img1_blurred = img1 - - f0 = self.encode(img0_blurred[:, :3]) - f1 = self.encode(img1_blurred[:, :3]) - flow_list = [] - mask_list = [] - flow = None - mask = None - block = [self.block0, self.block1, self.block2, self.block3] - for i in range(4): - if flow is None: - flow, mask = block[i]( - torch.cat( - (img0_blurred[:, :3], img1_blurred[:, :3], f0, f1, timestep), 1 - ), - None, - scale=scale_list[i], - ) - if ensemble: - f_, m_ = block[i]( + def compute_flow(img0_blurred, img1_blurred, timestep): + f0 = self.encode(img0_blurred[:, :3]) + f1 = self.encode(img1_blurred[:, :3]) + flow = None + mask = None + block = [self.block0, self.block1, self.block2, self.block3] + for i in range(4): + if flow is None: + flow, mask = block[i]( torch.cat( - ( - img1_blurred[:, :3], - img0_blurred[:, :3], - f1, - f0, - 1 - timestep, - ), - 1, + (img0_blurred[:, :3], img1_blurred[:, :3], f0, f1, timestep), 1 ), None, scale=scale_list[i], ) - flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2 - mask = (mask + (-m_)) / 2 - else: - wf0 = warp(f0, flow[:, :2], device) - wf1 = warp(f1, flow[:, 2:4], device) - fd, m0 = block[i]( - torch.cat( - ( - img0_blurred[:, :3], - img1_blurred[:, :3], - wf0, - wf1, - timestep, - mask, - ), - 1, - ), - flow, - scale=scale_list[i], - ) - if ensemble: - f_, m_ = block[i]( + if ensemble: + f_, m_ = block[i]( + torch.cat( + ( + img1_blurred[:, :3], + img0_blurred[:, :3], + f1, + f0, + 1 - timestep, + ), + 1, + ), + None, + scale=scale_list[i], + ) + flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2 + mask = (mask + (-m_)) / 2 + else: + wf0 = warp(f0, flow[:, :2], device) + wf1 = 
warp(f1, flow[:, 2:4], device) + fd, m0 = block[i]( torch.cat( ( - img1_blurred[:, :3], img0_blurred[:, :3], - wf1, + img1_blurred[:, :3], wf0, - 1 - timestep, - -mask, + wf1, + timestep, + mask, ), 1, ), - torch.cat((flow[:, 2:4], flow[:, :2]), 1), + flow, scale=scale_list[i], ) - fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2 - mask = (m0 + (-m_)) / 2 - else: - mask = m0 - flow = flow + fd - mask_list.append(mask) - flow_list.append(flow) + if ensemble: + f_, m_ = block[i]( + torch.cat( + ( + img1_blurred[:, :3], + img0_blurred[:, :3], + wf1, + wf0, + 1 - timestep, + -mask, + ), + 1, + ), + torch.cat((flow[:, 2:4], flow[:, :2]), 1), + scale=scale_list[i], + ) + fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2 + mask = (m0 + (-m_)) / 2 + else: + mask = m0 + flow = flow + fd + return flow + + # Optional blur + if blur_strength is not None and blur_strength > 0: + blur = transforms.GaussianBlur( + kernel_size=(5, 5), sigma=(blur_strength, blur_strength) + ) + img0_blurred = blur(img0) + if img1_blurred is None: + img1_blurred = blur(img1) + else: + img0_blurred = img0 + img1_blurred = img1 + + # align image to reference + flow1 = compute_flow(img0_blurred, img1_blurred, timestep) + + # align image to itself + if flow2 is None: + flow2 = compute_flow(img1_blurred, img1_blurred, timestep) + + # subtract flow2 from flow1 to compensate for shifts + new_flow = flow1 - flow2 - # apply warp to original image - aligned_img0 = warp(img0, flow_list[-1][:, :2], device) + # warp original unblurred image with the new flow + aligned_img0 = warp(img0, new_flow[:, :2], device) # add clamp here instead of in warplayer script, as it changes the output there - aligned_img0 = aligned_img0.clamp(min=0.0, max=1.0) - return aligned_img0, flow_list[-1] + aligned_img0 = aligned_img0.clamp(0, 1) + return aligned_img0, new_flow, flow2, img1_blurred def forward( self, @@ -262,9 +273,11 @@ def forward( else: timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3]) # type: 
ignore + flow2 = None + img1_blurred = None for _iteration in range(num_iterations): - aligned_img0, flow = self.align_images( - img0, img1, timestep, scale_list, blur_strength, ensemble, device + aligned_img0, flow, flow2, img1_blurred = self.align_images( + img0, img1, timestep, scale_list, blur_strength, ensemble, device, flow2, img1_blurred ) img0 = aligned_img0 # use the aligned image as img0 for the next iteration From d9b97e9da6e0215283956c42dc3ef7c12989a059 Mon Sep 17 00:00:00 2001 From: tepete Date: Thu, 22 Aug 2024 02:33:50 +0200 Subject: [PATCH 3/6] try fixing ruff errors --- .../nodes/impl/pytorch/rife/IFNet_HDv3_v4_14_align.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/backend/src/nodes/impl/pytorch/rife/IFNet_HDv3_v4_14_align.py b/backend/src/nodes/impl/pytorch/rife/IFNet_HDv3_v4_14_align.py index 9601ea0cd..c8293771a 100644 --- a/backend/src/nodes/impl/pytorch/rife/IFNet_HDv3_v4_14_align.py +++ b/backend/src/nodes/impl/pytorch/rife/IFNet_HDv3_v4_14_align.py @@ -147,7 +147,7 @@ def align_images( flow2=None, # noqa: ANN001 img1_blurred=None # noqa: ANN001 ): - def compute_flow(img0_blurred, img1_blurred, timestep): + def compute_flow(img0_blurred: np.ndarray, img1_blurred: np.ndarray, timestep: float) -> None: f0 = self.encode(img0_blurred[:, :3]) f1 = self.encode(img1_blurred[:, :3]) flow = None @@ -240,14 +240,14 @@ def compute_flow(img0_blurred, img1_blurred, timestep): flow2 = compute_flow(img1_blurred, img1_blurred, timestep) # subtract flow2 from flow1 to compensate for shifts - new_flow = flow1 - flow2 + compensated_flow = flow1 - flow2 - # warp original unblurred image with the new flow - aligned_img0 = warp(img0, new_flow[:, :2], device) + # warp image with the compensated flow + aligned_img0 = warp(img0, compensated_flow[:, :2], device) # add clamp here instead of in warplayer script, as it changes the output there aligned_img0 = aligned_img0.clamp(0, 1) - return aligned_img0, new_flow, flow2, img1_blurred + 
return aligned_img0, compensated_flow, flow2, img1_blurred def forward( self, From d2df79825f6bcc4faf51796007c3cb8567281992 Mon Sep 17 00:00:00 2001 From: tepete Date: Thu, 22 Aug 2024 02:53:41 +0200 Subject: [PATCH 4/6] try fixing ruff errors --- .../pytorch/rife/IFNet_HDv3_v4_14_align.py | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/backend/src/nodes/impl/pytorch/rife/IFNet_HDv3_v4_14_align.py b/backend/src/nodes/impl/pytorch/rife/IFNet_HDv3_v4_14_align.py index c8293771a..a1bfd568b 100644 --- a/backend/src/nodes/impl/pytorch/rife/IFNet_HDv3_v4_14_align.py +++ b/backend/src/nodes/impl/pytorch/rife/IFNet_HDv3_v4_14_align.py @@ -145,9 +145,11 @@ def align_images( ensemble, # noqa: ANN001 device, # noqa: ANN001 flow2=None, # noqa: ANN001 - img1_blurred=None # noqa: ANN001 + img1_blurred=None, # noqa: ANN001 ): - def compute_flow(img0_blurred: np.ndarray, img1_blurred: np.ndarray, timestep: float) -> None: + def compute_flow( + img0_blurred: np.ndarray, img1_blurred: np.ndarray, timestep: float + ) -> None: f0 = self.encode(img0_blurred[:, :3]) f1 = self.encode(img1_blurred[:, :3]) flow = None @@ -157,7 +159,14 @@ def compute_flow(img0_blurred: np.ndarray, img1_blurred: np.ndarray, timestep: f if flow is None: flow, mask = block[i]( torch.cat( - (img0_blurred[:, :3], img1_blurred[:, :3], f0, f1, timestep), 1 + ( + img0_blurred[:, :3], + img1_blurred[:, :3], + f0, + f1, + timestep, + ), + 1, ), None, scale=scale_list[i], @@ -277,7 +286,15 @@ def forward( img1_blurred = None for _iteration in range(num_iterations): aligned_img0, flow, flow2, img1_blurred = self.align_images( - img0, img1, timestep, scale_list, blur_strength, ensemble, device, flow2, img1_blurred + img0, + img1, + timestep, + scale_list, + blur_strength, + ensemble, + device, + flow2, + img1_blurred, ) img0 = aligned_img0 # use the aligned image as img0 for the next iteration From 7a7c565d7619956b4a537787881b43af6083b735 Mon Sep 17 00:00:00 2001 From: tepete 
Date: Thu, 22 Aug 2024 02:55:43 +0200 Subject: [PATCH 5/6] try fixing ruff errors --- .../pytorch/processing/align_image_to_reference.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/align_image_to_reference.py b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/align_image_to_reference.py index 89089d3b4..478fcc9b7 100644 --- a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/align_image_to_reference.py +++ b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/align_image_to_reference.py @@ -15,7 +15,6 @@ import numpy as np import requests import torch - from api import NodeContext from nodes.impl.pytorch.rife.IFNet_HDv3_v4_14_align import IFNet from nodes.impl.pytorch.utils import np2tensor, safe_cuda_cache_empty, tensor2np @@ -109,7 +108,7 @@ def align_images( target_img_resized = resize( target_img, (source_w, source_h), filter=ResizeFilter.LANCZOS ) - + # padding because rife can only work with multiples of 32 (changes with precision mode) pad_h, pad_w = calculate_padding(source_h, source_w, precision_mode) top_pad = pad_h // 2 From 4ad9b804510c56bdc6dbc98d49dbae9def440c7d Mon Sep 17 00:00:00 2001 From: tepete Date: Thu, 22 Aug 2024 03:05:06 +0200 Subject: [PATCH 6/6] try fixing ruff errors --- .../pytorch/processing/align_image_to_reference.py | 1 + 1 file changed, 1 insertion(+) diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/align_image_to_reference.py b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/align_image_to_reference.py index 478fcc9b7..ccb9e81db 100644 --- a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/align_image_to_reference.py +++ b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/align_image_to_reference.py @@ -15,6 +15,7 @@ import numpy as np import requests import torch + from api import NodeContext from nodes.impl.pytorch.rife.IFNet_HDv3_v4_14_align import IFNet from nodes.impl.pytorch.utils 
import np2tensor, safe_cuda_cache_empty, tensor2np