Skip to content

Commit

Permalink
rcan pixel unshuffle progress
Browse files Browse the repository at this point in the history
  • Loading branch information
the-database committed Jan 8, 2025
1 parent 60fe421 commit 2b56162
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
2 changes: 1 addition & 1 deletion traiNNer/archs/arch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@
"df2k_ssim": 0.8140,
},
},
"rcan": {
"rcan unshuffle_mod=False": {
2: {
"div2k_psnr": 33.34,
"div2k_ssim": 0.9384,
Expand Down
28 changes: 26 additions & 2 deletions traiNNer/archs/rcan_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections.abc import Callable

import torch
from spandrel.architectures.__arch_helpers.padding import pad_to_multiple
from spandrel.util import store_hyperparameters
from torch import Tensor, nn

Expand Down Expand Up @@ -256,10 +257,13 @@ def __init__(
reduction: int = 16,
res_scale: float = 1,
act_mode: str = "relu",
unshuffle_mod: bool = False,
conv: Callable[..., nn.Conv2d] = default_conv,
) -> None:
super().__init__()

self.scale = scale

if norm:
# RGB mean for DIV2K
self.rgb_range = rgb_range
Expand All @@ -273,7 +277,21 @@ def __init__(
self.add_mean = nn.Identity()

# define head module
modules_head = [conv(n_colors, n_feats, kernel_size)]
unshuffle_mod = unshuffle_mod and scale < 4
self.downscale_factor = 1
if unshuffle_mod:
self.downscale_factor = 4 // scale
scale = 4
modules_head = [
nn.PixelUnshuffle(self.downscale_factor),
conv(
n_colors * self.downscale_factor * self.downscale_factor,
n_feats,
kernel_size,
),
]
else:
modules_head = [conv(n_colors, n_feats, kernel_size)]

# define body module
modules_body: list[nn.Module] = [
Expand Down Expand Up @@ -301,7 +319,12 @@ def __init__(
self.body = nn.Sequential(*modules_body)
self.tail = nn.Sequential(*modules_tail)

def check_img_size(self, x: Tensor) -> Tensor:
return pad_to_multiple(x, self.downscale_factor, mode="reflect")

def forward(self, x: Tensor) -> Tensor:
_b, _c, h, w = x.shape
x = self.check_img_size(x)
x *= self.rgb_range
x = self.sub_mean(x)
x = self.head(x)
Expand All @@ -311,7 +334,8 @@ def forward(self, x: Tensor) -> Tensor:

x = self.tail(res)
x = self.add_mean(x)
return x / self.rgb_range
out = (x / self.rgb_range)[:, :, : h * self.scale, : w * self.scale]
return out


# @ARCH_REGISTRY.register()
Expand Down

0 comments on commit 2b56162

Please sign in to comment.