Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the "Align Image to Reference" node further #2956

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 102 additions & 72 deletions backend/src/nodes/impl/pytorch/rife/IFNet_HDv3_v4_14_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,99 +144,119 @@ def align_images(
blur_strength, # noqa: ANN001
ensemble, # noqa: ANN001
device, # noqa: ANN001
flow2=None, # noqa: ANN001
img1_blurred=None, # noqa: ANN001
):
# optional blur
if blur_strength is not None and blur_strength > 0:
blur = transforms.GaussianBlur(
kernel_size=(5, 5), sigma=(blur_strength, blur_strength)
)
img0_blurred = blur(img0)
img1_blurred = blur(img1)
else:
img0_blurred = img0
img1_blurred = img1

f0 = self.encode(img0_blurred[:, :3])
f1 = self.encode(img1_blurred[:, :3])
flow_list = []
mask_list = []
flow = None
mask = None
block = [self.block0, self.block1, self.block2, self.block3]
for i in range(4):
if flow is None:
flow, mask = block[i](
torch.cat(
(img0_blurred[:, :3], img1_blurred[:, :3], f0, f1, timestep), 1
),
None,
scale=scale_list[i],
)
if ensemble:
f_, m_ = block[i](
def compute_flow(
img0_blurred: np.ndarray, img1_blurred: np.ndarray, timestep: float
) -> None:
f0 = self.encode(img0_blurred[:, :3])
f1 = self.encode(img1_blurred[:, :3])
flow = None
mask = None
block = [self.block0, self.block1, self.block2, self.block3]
for i in range(4):
if flow is None:
flow, mask = block[i](
torch.cat(
(
img1_blurred[:, :3],
img0_blurred[:, :3],
f1,
img1_blurred[:, :3],
f0,
1 - timestep,
f1,
timestep,
),
1,
),
None,
scale=scale_list[i],
)
flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
mask = (mask + (-m_)) / 2
else:
wf0 = warp(f0, flow[:, :2], device)
wf1 = warp(f1, flow[:, 2:4], device)
fd, m0 = block[i](
torch.cat(
(
img0_blurred[:, :3],
img1_blurred[:, :3],
wf0,
wf1,
timestep,
mask,
),
1,
),
flow,
scale=scale_list[i],
)
if ensemble:
f_, m_ = block[i](
if ensemble:
f_, m_ = block[i](
torch.cat(
(
img1_blurred[:, :3],
img0_blurred[:, :3],
f1,
f0,
1 - timestep,
),
1,
),
None,
scale=scale_list[i],
)
flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
mask = (mask + (-m_)) / 2
else:
wf0 = warp(f0, flow[:, :2], device)
wf1 = warp(f1, flow[:, 2:4], device)
fd, m0 = block[i](
torch.cat(
(
img1_blurred[:, :3],
img0_blurred[:, :3],
wf1,
img1_blurred[:, :3],
wf0,
1 - timestep,
-mask,
wf1,
timestep,
mask,
),
1,
),
torch.cat((flow[:, 2:4], flow[:, :2]), 1),
flow,
scale=scale_list[i],
)
fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
mask = (m0 + (-m_)) / 2
else:
mask = m0
flow = flow + fd
mask_list.append(mask)
flow_list.append(flow)
if ensemble:
f_, m_ = block[i](
torch.cat(
(
img1_blurred[:, :3],
img0_blurred[:, :3],
wf1,
wf0,
1 - timestep,
-mask,
),
1,
),
torch.cat((flow[:, 2:4], flow[:, :2]), 1),
scale=scale_list[i],
)
fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
mask = (m0 + (-m_)) / 2
else:
mask = m0
flow = flow + fd
return flow

# Optional blur
if blur_strength is not None and blur_strength > 0:
blur = transforms.GaussianBlur(
kernel_size=(5, 5), sigma=(blur_strength, blur_strength)
)
img0_blurred = blur(img0)
if img1_blurred is None:
img1_blurred = blur(img1)
else:
img0_blurred = img0
img1_blurred = img1

# align image to reference
flow1 = compute_flow(img0_blurred, img1_blurred, timestep)

# align image to itself
if flow2 is None:
flow2 = compute_flow(img1_blurred, img1_blurred, timestep)

# subtract flow2 from flow1 to compensate for shifts
compensated_flow = flow1 - flow2

# apply warp to original image
aligned_img0 = warp(img0, flow_list[-1][:, :2], device)
# warp image with the compensated flow
aligned_img0 = warp(img0, compensated_flow[:, :2], device)

# add clamp here instead of in warplayer script, as it changes the output there
aligned_img0 = aligned_img0.clamp(min=0.0, max=1.0)
return aligned_img0, flow_list[-1]
aligned_img0 = aligned_img0.clamp(0, 1)
return aligned_img0, compensated_flow, flow2, img1_blurred

def forward(
self,
Expand All @@ -262,9 +282,19 @@ def forward(
else:
timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3]) # type: ignore

flow2 = None
img1_blurred = None
for _iteration in range(num_iterations):
aligned_img0, flow = self.align_images(
img0, img1, timestep, scale_list, blur_strength, ensemble, device
aligned_img0, flow, flow2, img1_blurred = self.align_images(
img0,
img1,
timestep,
scale_list,
blur_strength,
ensemble,
device,
flow2,
img1_blurred,
)
img0 = aligned_img0 # use the aligned image as img0 for the next iteration

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,10 @@ def align_images(

source_h, source_w, _ = get_h_w_c(source_img)

# resize, then shift reference left because rife shifts slightly to the right
# resize image to reference
target_img_resized = resize(
target_img, (source_w, source_h), filter=ResizeFilter.LANCZOS
)
target_img_resized = np.roll(target_img_resized, -1, axis=1)
target_img_resized[:, -1] = target_img_resized[:, -2]

# padding because rife can only work with multiples of 32 (changes with precision mode)
pad_h, pad_w = calculate_padding(source_h, source_w, precision_mode)
Expand Down
Loading