Skip to content

Commit

Permalink
feature: add tiling ability for upscalers
Browse files (browse the repository at this point in the history)
-accept paths and urls for upscale weights
  • Loading branch information
brycedrennan committed Jan 23, 2024
1 parent bd5be80 commit bb7b636
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 18 deletions.
4 changes: 4 additions & 0 deletions imaginairy/cli/upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import click

from imaginairy.utils.log_utils import configure_logging

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -83,6 +85,8 @@ def upscale_cmd(
from imaginairy.schema import LazyLoadingImage
from imaginairy.utils import glob_expand_paths

configure_logging()

os.makedirs(outdir, exist_ok=True)
image_filepaths = glob_expand_paths(image_filepaths)
for p in tqdm(image_filepaths):
Expand Down
75 changes: 57 additions & 18 deletions imaginairy/enhancers/upscale.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,84 @@
import logging

import torch
import torchvision.transforms.functional as F
from PIL import Image
from spandrel import ImageModelDescriptor, ModelLoader
from torchvision import transforms

from imaginairy.schema import LazyLoadingImage
from imaginairy.utils import get_device
from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.utils.tile_up import tile_process

# Short model names mapped to the URL their weights are downloaded from.
upscale_model_lookup = {
    # RealESRGAN
    "ultrasharp": "https://huggingface.co/lokCX/4x-Ultrasharp/resolve/1856559b50de25116a7c07261177dd128f1f5664/4x-UltraSharp.pth",
    "realesrgan": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
    "realesrgan-x4-plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
    "realesrgan-x2-plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
    # ESRGAN
    "esrgan-x4": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
    # HAT
    "HAT": "https://huggingface.co/Acly/hat/resolve/main/HAT_SRx4_ImageNet-pretrain.pth?download=true",
    "real-hat": "https://huggingface.co/imaginairy/model-weights/resolve/main/weights/super-resolution/hat/Real_HAT_GAN_SRx4.safetensors",
    "real-hat-sharper": "https://huggingface.co/imaginairy/model-weights/resolve/main/weights/super-resolution/hat/Real_HAT_GAN_sharper.safetensors",
    "4xNomos8kHAT-L": "https://huggingface.co/imaginairy/model-weights/resolve/main/weights/super-resolution/hat/4xNomos8kHAT-L_otf.safetensors",
}
# Previous name of the table; kept as an alias so references to either name resolve.
upscale_models = upscale_model_lookup
logger = logging.getLogger(__name__)


def upscale_image(
    img: LazyLoadingImage | Image.Image,
    upscaler_model: str = "realesrgan",
    tile_size=512,
    tile_pad=50,
    repetition=2,
    device=None,
):
    """Upscale an image with a super-resolution model and return a PIL image.

    Args:
        img: The image to upscale.
        upscaler_model: A key of ``upscale_model_lookup``, an ``http(s)://``
            URL of a weights file, or any other string, which is handed to
            the model loader unchanged as a file path.
        tile_size: Edge length of the tiles fed to the model. A value of
            ``0`` or less disables tiling and runs the model on the whole
            image in one call.
        tile_pad: Pixels of surrounding context added to every tile.
        repetition: How many times the model is applied in sequence.
        device: Device to run on; defaults to ``get_device()``.

    Returns:
        The upscaled image, resized to ``img.width * model.scale`` by
        ``img.height * model.scale`` regardless of ``repetition``.
    """
    device = device or get_device()

    # Resolve the model argument to a local weights file.
    if upscaler_model in upscale_model_lookup:
        model_path = get_cached_url_path(upscale_model_lookup[upscaler_model])
    elif upscaler_model.startswith(("https://", "http://")):
        model_path = get_cached_url_path(upscaler_model)
    else:
        model_path = upscaler_model

    model = ModelLoader().load_from_file(model_path)
    # Check the descriptor type before any of its attributes are read.
    assert isinstance(model, ImageModelDescriptor)

    target_size = (img.width * model.scale, img.height * model.scale)
    logger.info(
        f"Upscaling from {img.width}x{img.height} to {target_size[0]}x{target_size[1]}"
    )
    print(f"Upscaling image with model {model.architecture}@{upscaler_model}")

    model.to(torch.device(device)).eval()

    image_tensor = load_image(img).to(device)
    with torch.no_grad():
        for _ in range(repetition):
            if tile_size > 0:
                image_tensor = tile_process(
                    image_tensor,
                    scale=model.scale,
                    model=model,
                    tile_size=tile_size,
                    tile_pad=tile_pad,
                )
            else:
                image_tensor = model(image_tensor)

    # Drop the batch axis before converting back to a PIL image.
    image = F.to_pil_image(image_tensor.squeeze(0))
    # Repeated passes enlarge the tensor beyond one scale step; the final
    # image is always brought back to a single scale step of the input size.
    return image.resize(target_size)


def load_image(img: LazyLoadingImage | Image.Image, device=None):
    """Convert an image into a tensor with a leading batch axis.

    Args:
        img: The image to convert. Objects exposing ``as_pillow()`` are
            converted through it; anything else is passed to the tensor
            transform as-is.
        device: Device to place the tensor on; defaults to ``get_device()``.

    Returns:
        The ``ToTensor`` result with one extra axis inserted at position 0,
        moved to the chosen device.
    """
    # upscale_image also accepts plain PIL images, which have no as_pillow().
    pil_img = img.as_pillow() if hasattr(img, "as_pillow") else img
    image_tensor = transforms.ToTensor()(pil_img).unsqueeze(0)
    return image_tensor.to(device or get_device())
10 changes: 10 additions & 0 deletions imaginairy/utils/log_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,16 @@ def disable_common_warnings():
"ignore", category=UserWarning, message=r"Arguments other than a weight.*"
)
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings(
"ignore",
category=UserWarning,
message=r".*?torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument..*?",
)
warnings.filterwarnings(
"ignore",
category=UserWarning,
message=r".*?is not currently supported on the MPS backend and will fall back.*?",
)


def suppress_annoying_logs_and_warnings():
Expand Down
87 changes: 87 additions & 0 deletions imaginairy/utils/tile_up.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import logging
import math

import torch
from torch import Tensor

logger = logging.getLogger(__name__)


def _axis_spans(length: int, tile_size: int, tile_pad: int):
    """Yield ``(start, end, padded_start, padded_end)`` for each tile along one axis."""
    for start in range(0, length, tile_size):
        end = min(start + tile_size, length)
        yield start, end, max(start - tile_pad, 0), min(end + tile_pad, length)


def tile_process(
    img: Tensor,
    scale: int,
    model: torch.nn.Module,
    tile_size: int = 512,
    tile_pad: int = 50,
) -> Tensor:
    """
    Process an image by tiling it, processing each tile, and then merging them back into one image.

    Each tile is enlarged by ``tile_pad`` pixels on every side (clipped at the
    image border) before it is passed to ``model``; the padded margin is cut
    off again when the result is written into the output.

    Args:
        img (Tensor): The input image tensor, unpacked as (batch, channel, height, width).
        scale (int): The scale factor for the image.
        model (torch.nn.Module): The model used for processing the tile. Its
            output is expected to keep the channel count of the input and to
            be ``scale`` times larger in height and width.
        tile_size (int): The size of each tile. Must be positive.
        tile_pad (int): The padding for each tile. Must not be negative.
    Returns:
        Tensor: The processed output image.
    Raises:
        ValueError: If ``tile_size`` is not positive or ``tile_pad`` is negative.
    """
    # A zero tile size would divide by zero and a negative one would silently
    # produce an all-zero image, so reject both up front.
    if tile_size <= 0:
        raise ValueError(f"tile_size must be a positive integer, got {tile_size}")
    if tile_pad < 0:
        raise ValueError(f"tile_pad must not be negative, got {tile_pad}")

    batch, channel, height, width = img.shape

    # Initialize the output tensor with the dtype and device of the input
    output = img.new_zeros((batch, channel, height * scale, width * scale))
    tiles_x = math.ceil(width / tile_size)
    tiles_y = math.ceil(height / tile_size)
    logger.info("Tiling with %sx%s (%s) tiles", tiles_x, tiles_y, tiles_x * tiles_y)

    # The column spans are the same for every row of tiles
    x_spans = list(_axis_spans(width, tile_size, tile_pad))

    with torch.no_grad():
        for y0, y1, pad_y0, pad_y1 in _axis_spans(height, tile_size, tile_pad):
            for x0, x1, pad_x0, pad_x1 in x_spans:
                # Process the padded tile
                output_tile = model(img[:, :, pad_y0:pad_y1, pad_x0:pad_x1])

                # Offset of the unpadded region inside the processed tile
                crop_y0 = (y0 - pad_y0) * scale
                crop_x0 = (x0 - pad_x0) * scale
                crop_y1 = crop_y0 + (y1 - y0) * scale
                crop_x1 = crop_x0 + (x1 - x0) * scale

                # Place the unpadded part of the processed tile in the output image
                output[:, :, y0 * scale : y1 * scale, x0 * scale : x1 * scale] = (
                    output_tile[:, :, crop_y0:crop_y1, crop_x0:crop_x1]
                )

    return output

0 comments on commit bb7b636

Please sign in to comment.