feature: video interpolation
- uses the RIFE algorithm to interpolate frames
brycedrennan committed Jan 8, 2024
1 parent bb2dd45 commit 20da2a2
Showing 9 changed files with 1,060 additions and 12 deletions.
28 changes: 23 additions & 5 deletions imaginairy/api/video_sample.py
@@ -19,6 +19,7 @@
 from torchvision.transforms import ToTensor

 from imaginairy import config
+from imaginairy.enhancers.video_interpolation.rife.interpolate import interpolate_images
 from imaginairy.schema import LazyLoadingImage
 from imaginairy.utils import (
     default,
@@ -91,7 +92,6 @@ def generate_video(
         )
         logger.warning(msg)

-    start_time = time.perf_counter()
     output_fps = default(output_fps, fps_id)

     video_model_config = config.MODEL_WEIGHT_CONFIG_LOOKUP.get(model_name, None)
@@ -139,6 +139,7 @@ def generate_video(
     expected_size = (vid_width, vid_height)
     for _ in range(repetitions):
         for input_path in all_img_paths:
+            start_time = time.perf_counter()
             _seed = default(seed, random.randint(0, 1000000))
             torch.manual_seed(_seed)
             logger.info(
@@ -318,15 +319,32 @@ def save_video(samples: torch.Tensor, video_filename: str, output_fps: int):
os.system(f"ffmpeg -i {video_filename} -c:v libx264 {video_path_h264}")


def save_video_bounce(samples: torch.Tensor, video_filename: str, output_fps: int):
def save_video_bounce(
samples: torch.Tensor, video_filename: str, output_fps: int, interpolate_fps=60
):
frames_np = (
(torch.permute(samples, (0, 2, 3, 1)) * 255).cpu().numpy().astype(np.uint8)
)

transition_duration = len(frames_np) / float(output_fps)
frames_pil = [Image.fromarray(frame) for frame in frames_np]
if interpolate_fps:
# bring it up to at least 60 fps
fps_multiplier = int(math.ceil(interpolate_fps / output_fps))
frames_pil = interpolate_images(frames_pil, fps_multiplier=fps_multiplier)

transition_duration_ms = transition_duration * 1000
logger.info(
f"Interpolated from {len(frames_np)} to {len(frames_pil)} frames ({fps_multiplier} multiplier)"
)
logger.info(
f"Making bounce animation with transition duration {transition_duration_ms:.1f}ms"
)
make_bounce_animation(
imgs=[Image.fromarray(frame) for frame in frames_np],
imgs=frames_pil,
outpath=video_filename,
end_pause_duration_ms=750,
transition_duration_ms=transition_duration_ms,
end_pause_duration_ms=100,
max_fps=60,
)


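For a sense of the frame-rate arithmetic above, a minimal sketch (standard library only; the 24 fps value is illustrative, not from this commit):

import math

output_fps = 24  # fps of the generated clip (illustrative)
interpolate_fps = 60  # minimum playback rate we want to reach
fps_multiplier = int(math.ceil(interpolate_fps / output_fps))  # ceil(2.5) == 3
# e.g. 25 source frames become about 25 * 3 = 75 frames after interpolation,
# so the bounce animation plays at >= 60 fps over the same duration.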
Empty file.
227 changes: 227 additions & 0 deletions imaginairy/enhancers/video_interpolation/rife/IFNet_HDv3.py
@@ -0,0 +1,227 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=True,
        ),
        nn.LeakyReLU(0.2, True),
    )


def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,
        ),
        nn.BatchNorm2d(out_planes),
        nn.LeakyReLU(0.2, True),
    )


class Head(nn.Module):
    def __init__(self):
        super().__init__()
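        # Shallow feature encoder: a stride-2 conv halves resolution, two 3x3
        # convs refine, and a transposed conv restores resolution with 8 channels.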
        self.cnn0 = nn.Conv2d(3, 32, 3, 2, 1)
        self.cnn1 = nn.Conv2d(32, 32, 3, 1, 1)
        self.cnn2 = nn.Conv2d(32, 32, 3, 1, 1)
        self.cnn3 = nn.ConvTranspose2d(32, 8, 4, 2, 1)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x, feat=False):
        x0 = self.cnn0(x)
        x = self.relu(x0)
        x1 = self.cnn1(x)
        x = self.relu(x1)
        x2 = self.cnn2(x)
        x = self.relu(x2)
        x3 = self.cnn3(x)
        if feat:
            return [x0, x1, x2, x3]
        return x3


class ResConv(nn.Module):
    def __init__(self, c, dilation=1):
        super().__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        return self.relu(self.conv(x) * self.beta + x)


class IFBlock(nn.Module):
    def __init__(self, in_planes, c=64):
        super().__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
        )
        self.lastconv = nn.Sequential(
            nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2)
        )

    def forward(self, x, flow=None, scale=1):
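        # Work at 1/scale resolution: downsample x (and flow, rescaling its
        # magnitudes), predict a residual flow and blend mask, then upsample back.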
        x = F.interpolate(
            x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False
        )
        if flow is not None:
            flow = (
                F.interpolate(
                    flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False
                )
                * 1.0
                / scale
            )
            x = torch.cat((x, flow), 1)
        feat = self.conv0(x)
        feat = self.convblock(feat)
        tmp = self.lastconv(feat)
        tmp = F.interpolate(
            tmp, scale_factor=scale, mode="bilinear", align_corners=False
        )
        flow = tmp[:, :4] * scale
        mask = tmp[:, 4:5]
        return flow, mask


class IFNet(nn.Module):
    def __init__(self):
        super().__init__()
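        # Channel bookkeeping: block0 sees 3+3 RGB + 1 timestep (=7) plus 8+8
        # Head features (=16); blocks 1-3 see 3+3 warped RGB + 1 timestep +
        # 1 mask (=8), 4 flow channels, and 8+8 warped features (=16).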
        self.block0 = IFBlock(7 + 16, c=192)
        self.block1 = IFBlock(8 + 4 + 16, c=128)
        self.block2 = IFBlock(8 + 4 + 16, c=96)
        self.block3 = IFBlock(8 + 4 + 16, c=64)
        self.encode = Head()
        # self.contextnet = Contextnet()
        # self.unet = Unet()

    def forward(
        self,
        x,
        timestep=0.5,
        scale_list=[8, 4, 2, 1],
        training=False,
        fastmode=True,
        ensemble=False,
    ):
        if training is False:
            channel = x.shape[1] // 2
            img0 = x[:, :channel]
            img1 = x[:, channel:]
        if not torch.is_tensor(timestep):
            timestep = (x[:, :1].clone() * 0 + 1) * timestep
        else:
            timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3])
        f0 = self.encode(img0[:, :3])
        f1 = self.encode(img1[:, :3])
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]
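        # Coarse-to-fine: four passes at 1/8, 1/4, 1/2, and full resolution,
        # each refining the accumulated bidirectional flow and blend mask.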
        for i in range(4):
            if flow is None:
                flow, mask = block[i](
                    torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1),
                    None,
                    scale=scale_list[i],
                )
                if ensemble:
                    f_, m_ = block[i](
                        torch.cat((img1[:, :3], img0[:, :3], f1, f0, 1 - timestep), 1),
                        None,
                        scale=scale_list[i],
                    )
                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (mask + (-m_)) / 2
            else:
                wf0 = warp(f0, flow[:, :2])
                wf1 = warp(f1, flow[:, 2:4])
                fd, m0 = block[i](
                    torch.cat(
                        (
                            warped_img0[:, :3],
                            warped_img1[:, :3],
                            wf0,
                            wf1,
                            timestep,
                            mask,
                        ),
                        1,
                    ),
                    flow,
                    scale=scale_list[i],
                )
                if ensemble:
                    f_, m_ = block[i](
                        torch.cat(
                            (
                                warped_img1[:, :3],
                                warped_img0[:, :3],
                                wf1,
                                wf0,
                                1 - timestep,
                                -mask,
                            ),
                            1,
                        ),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=scale_list[i],
                    )
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2
                else:
                    mask = m0
                flow = flow + fd
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            merged.append((warped_img0, warped_img1))
        mask = torch.sigmoid(mask)
        merged[3] = warped_img0 * mask + warped_img1 * (1 - mask)
        if not fastmode:
            print("contextnet is removed")
            """
            c0 = self.contextnet(img0, flow[:, :2])
            c1 = self.contextnet(img1, flow[:, 2:4])
            tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
            res = tmp[:, :3] * 2 - 1
            merged[3] = torch.clamp(merged[3] + res, 0, 1)
            """
        return flow_list, mask_list[3], merged
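For orientation, a minimal sketch of exercising IFNet directly (shapes and values are assumptions: the weights here are untrained, so only the tensor plumbing is meaningful; H and W must be divisible by 32, since the scale-8 pass downsamples by another factor of 4 inside IFBlock; imaginairy drives the network through the Model wrapper in RIFE_HDv3.py below):

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = IFNet().to(device).eval()
img0 = torch.rand(1, 3, 256, 448, device=device)  # frame at t=0, values in [0, 1]
img1 = torch.rand(1, 3, 256, 448, device=device)  # frame at t=1
with torch.no_grad():
    flow_list, mask, merged = net(torch.cat((img0, img1), 1), timestep=0.5)
mid = merged[3]  # blended intermediate frame, shape (1, 3, 256, 448)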
50 changes: 50 additions & 0 deletions imaginairy/enhancers/video_interpolation/rife/RIFE_HDv3.py
@@ -0,0 +1,50 @@
import torch

from .IFNet_HDv3 import IFNet


class Model:
    def __init__(self):
        self.flownet = IFNet()
        self.version: float

    def eval(self):
        self.flownet.eval()

    def load_model(self, path, version: float):
        from safetensors import safe_open

        tensors = {}
        with safe_open(path, framework="pt") as f:  # type: ignore
            for key in f.keys():  # noqa
                tensors[key] = f.get_tensor(key)
        self.flownet.load_state_dict(tensors, assign=True)
        self.version = version

    def load_model_old(self, path, rank=0):
        def convert(param):
            if rank == -1:
                return {
                    k.replace("module.", ""): v
                    for k, v in param.items()
                    if "module." in k
                }
            else:
                return param

        if rank <= 0:
            if torch.cuda.is_available():
                self.flownet.load_state_dict(
                    convert(torch.load(f"{path}/flownet.pkl")), False
                )
            else:
                self.flownet.load_state_dict(
                    convert(torch.load(f"{path}/flownet.pkl", map_location="cpu")),
                    False,
                )

    def inference(self, img0, img1, timestep=0.5, scale=1.0):
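        # Larger `scale` evaluates the flow at higher resolution; the
        # coarse-to-fine schedule [8, 4, 2, 1] shrinks proportionally.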
        imgs = torch.cat((img0, img1), 1)
        scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        flow, mask, merged = self.flownet(imgs, timestep, scale_list)
        return merged[3]
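A minimal usage sketch, assuming a RIFE-HDv3-compatible safetensors checkpoint (the file name and version value are hypothetical, not from this commit) and img0/img1 prepared as in the IFNet sketch above:

model = Model()
model.load_model("rife-flownet.safetensors", version=4.14)  # hypothetical path and version
model.eval()
model.flownet.to(img0.device)  # keep the weights on the same device as the inputs
with torch.no_grad():
    middle = model.inference(img0, img1, timestep=0.5)  # frame halfway between inputs
    quarter = model.inference(img0, img1, timestep=0.25)  # frame closer to img0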
Empty file.