From 2b56162be360ea0ebb9960c06dad6942acd9c049 Mon Sep 17 00:00:00 2001 From: the-database <25811902+the-database@users.noreply.github.com> Date: Wed, 8 Jan 2025 18:48:11 -0500 Subject: [PATCH] rcan pixel unshuffle progress --- traiNNer/archs/arch_info.py | 2 +- traiNNer/archs/rcan_arch.py | 28 ++++++++++++++++++++++++++-- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/traiNNer/archs/arch_info.py b/traiNNer/archs/arch_info.py index 17c55918..ddf5a5c1 100644 --- a/traiNNer/archs/arch_info.py +++ b/traiNNer/archs/arch_info.py @@ -334,7 +334,7 @@ "df2k_ssim": 0.8140, }, }, - "rcan": { + "rcan unshuffle_mod=False": { 2: { "div2k_psnr": 33.34, "div2k_ssim": 0.9384, diff --git a/traiNNer/archs/rcan_arch.py b/traiNNer/archs/rcan_arch.py index 54f4b280..52e6bcaa 100644 --- a/traiNNer/archs/rcan_arch.py +++ b/traiNNer/archs/rcan_arch.py @@ -2,6 +2,7 @@ from collections.abc import Callable import torch +from spandrel.architectures.__arch_helpers.padding import pad_to_multiple from spandrel.util import store_hyperparameters from torch import Tensor, nn @@ -256,10 +257,13 @@ def __init__( reduction: int = 16, res_scale: float = 1, act_mode: str = "relu", + unshuffle_mod: bool = False, conv: Callable[..., nn.Conv2d] = default_conv, ) -> None: super().__init__() + self.scale = scale + if norm: # RGB mean for DIV2K self.rgb_range = rgb_range @@ -273,7 +277,21 @@ def __init__( self.add_mean = nn.Identity() # define head module - modules_head = [conv(n_colors, n_feats, kernel_size)] + unshuffle_mod = unshuffle_mod and scale < 4 + self.downscale_factor = 1 + if unshuffle_mod: + self.downscale_factor = 4 // scale + scale = 4 + modules_head = [ + nn.PixelUnshuffle(self.downscale_factor), + conv( + n_colors * self.downscale_factor * self.downscale_factor, + n_feats, + kernel_size, + ), + ] + else: + modules_head = [conv(n_colors, n_feats, kernel_size)] # define body module modules_body: list[nn.Module] = [ @@ -301,7 +319,12 @@ def __init__( self.body = nn.Sequential(*modules_body) self.tail = nn.Sequential(*modules_tail) + def check_img_size(self, x: Tensor) -> Tensor: + return pad_to_multiple(x, self.downscale_factor, mode="reflect") + def forward(self, x: Tensor) -> Tensor: + _b, _c, h, w = x.shape + x = self.check_img_size(x) x *= self.rgb_range x = self.sub_mean(x) x = self.head(x) @@ -311,7 +334,8 @@ def forward(self, x: Tensor) -> Tensor: x = self.tail(res) x = self.add_mean(x) - return x / self.rgb_range + out = (x / self.rgb_range)[:, :, : h * self.scale, : w * self.scale] + return out # @ARCH_REGISTRY.register()