
Commit

Merge pull request #220 from the-database/dev
default to bf16 for rcan, moesr
the-database authored Jan 7, 2025
2 parents 8a28d69 + 61b7e79 commit 5f8e6d4
Showing 15 changed files with 45 additions and 16 deletions.
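For context on the YAML flips below: use_amp enables PyTorch autocast and amp_bf16 selects its dtype. A minimal sketch of the mechanism, assuming the trainer wires these flags onto torch.autocast roughly as follows; the train_step helper and the lq/gt batch keys are illustrative assumptions, not the repo's actual trainer code:

import torch

use_amp = True
amp_bf16 = True  # the new default for the RCAN and MoESR configs below

# bf16 keeps fp32's exponent range, so it usually needs no loss scaling;
# fp16 has a narrow range and generally does.
amp_dtype = torch.bfloat16 if amp_bf16 else torch.float16
scaler = torch.cuda.amp.GradScaler(enabled=use_amp and amp_dtype == torch.float16)

def train_step(net, batch, criterion, optimizer):
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=use_amp):
        loss = criterion(net(batch["lq"]), batch["gt"])
    scaler.scale(loss).backward()  # scaling is a no-op when the scaler is disabled
    scaler.step(optimizer)
    scaler.update()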
2 changes: 1 addition & 1 deletion options/test/MoESR/MoESR2.yml
@@ -5,7 +5,7 @@
 name: 4x_MoESR2
 scale: 4 # 1, 2, 3, 4, 8
 use_amp: true # Speed up training and reduce VRAM usage. NVIDIA only.
-amp_bf16: false # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only.
+amp_bf16: true # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only.
 num_gpu: auto
2 changes: 1 addition & 1 deletion options/test/RCAN/RCAN.yml
@@ -5,7 +5,7 @@
 name: 4x_RCAN
 scale: 4 # 1, 2, 3, 4, 8
 use_amp: true # Speed up training and reduce VRAM usage. NVIDIA only.
-amp_bf16: false # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only.
+amp_bf16: true # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only.
 num_gpu: auto
2 changes: 1 addition & 1 deletion options/train/MoESR/MoESR2_OTF_bicubic_ms_ssim_l1.yml
@@ -6,7 +6,7 @@
 name: 4x_MoESR2_OTF_bicubic_ms_ssim_l1
 scale: 4 # 1, 2, 3, 4, 8
 use_amp: true # Speed up training and reduce VRAM usage. NVIDIA only.
-amp_bf16: false # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only. Only recommended if fp16 doesn't work.
+amp_bf16: true # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only. Only recommended if fp16 doesn't work.
 use_channels_last: false # Enable channels last memory format while using AMP. Reduces VRAM and speeds up training for most architectures, but some architectures are slower with channels last.
 fast_matmul: false # Trade precision for performance.
 num_gpu: auto
2 changes: 1 addition & 1 deletion options/train/MoESR/MoESR2_OTF_finetune.yml
@@ -6,7 +6,7 @@
 name: 4x_MoESR2_OTF_finetune
 scale: 4 # 1, 2, 3, 4, 8
 use_amp: true # Speed up training and reduce VRAM usage. NVIDIA only.
-amp_bf16: false # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only. Only recommended if fp16 doesn't work.
+amp_bf16: true # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only. Only recommended if fp16 doesn't work.
 use_channels_last: false # Enable channels last memory format while using AMP. Reduces VRAM and speeds up training for most architectures, but some architectures are slower with channels last.
 fast_matmul: false # Trade precision for performance.
 num_gpu: auto
2 changes: 1 addition & 1 deletion options/train/MoESR/MoESR2_OTF_fromscratch.yml
@@ -6,7 +6,7 @@
 name: 4x_MoESR2_OTF_fromscratch
 scale: 4 # 1, 2, 3, 4, 8
 use_amp: true # Speed up training and reduce VRAM usage. NVIDIA only.
-amp_bf16: false # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only. Only recommended if fp16 doesn't work.
+amp_bf16: true # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only. Only recommended if fp16 doesn't work.
 use_channels_last: false # Enable channels last memory format while using AMP. Reduces VRAM and speeds up training for most architectures, but some architectures are slower with channels last.
 fast_matmul: false # Trade precision for performance.
 num_gpu: auto
2 changes: 1 addition & 1 deletion options/train/MoESR/MoESR2_finetune.yml
@@ -6,7 +6,7 @@
 name: 4x_MoESR2
 scale: 4 # 1, 2, 3, 4, 8
 use_amp: true # Speed up training and reduce VRAM usage. NVIDIA only.
-amp_bf16: false # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only. Only recommended if fp16 doesn't work.
+amp_bf16: true # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only. Only recommended if fp16 doesn't work.
 use_channels_last: false # Enable channels last memory format while using AMP. Reduces VRAM and speeds up training for most architectures, but some architectures are slower with channels last.
 fast_matmul: false # Trade precision for performance.
 num_gpu: auto
2 changes: 1 addition & 1 deletion options/train/MoESR/MoESR2_fromscratch.yml
@@ -6,7 +6,7 @@
 name: 4x_MoESR2
 scale: 4 # 1, 2, 3, 4, 8
 use_amp: true # Speed up training and reduce VRAM usage. NVIDIA only.
-amp_bf16: false # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only. Only recommended if fp16 doesn't work.
+amp_bf16: true # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only. Only recommended if fp16 doesn't work.
 use_channels_last: false # Enable channels last memory format while using AMP. Reduces VRAM and speeds up training for most architectures, but some architectures are slower with channels last.
 fast_matmul: false # Trade precision for performance.
 num_gpu: auto
2 changes: 1 addition & 1 deletion options/train/RCAN/RCAN_OTF_bicubic_ms_ssim_l1.yml
@@ -6,7 +6,7 @@
 name: 4x_RCAN_OTF_bicubic_ms_ssim_l1
 scale: 4 # 1, 2, 3, 4, 8
 use_amp: true # Speed up training and reduce VRAM usage. NVIDIA only.
-amp_bf16: false # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only. Only recommended if fp16 doesn't work.
+amp_bf16: true # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only. Only recommended if fp16 doesn't work.
 use_channels_last: true # Enable channels last memory format while using AMP. Reduces VRAM and speeds up training for most architectures, but some architectures are slower with channels last.
 fast_matmul: false # Trade precision for performance.
 num_gpu: auto
2 changes: 1 addition & 1 deletion options/train/RCAN/RCAN_OTF_finetune.yml
@@ -6,7 +6,7 @@
 name: 4x_RCAN_OTF_finetune
 scale: 4 # 1, 2, 3, 4, 8
 use_amp: true # Speed up training and reduce VRAM usage. NVIDIA only.
-amp_bf16: false # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only. Only recommended if fp16 doesn't work.
+amp_bf16: true # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only. Only recommended if fp16 doesn't work.
 use_channels_last: true # Enable channels last memory format while using AMP. Reduces VRAM and speeds up training for most architectures, but some architectures are slower with channels last.
 fast_matmul: false # Trade precision for performance.
 num_gpu: auto
2 changes: 1 addition & 1 deletion options/train/RCAN/RCAN_OTF_fromscratch.yml
@@ -6,7 +6,7 @@
 name: 4x_RCAN_OTF_fromscratch
 scale: 4 # 1, 2, 3, 4, 8
 use_amp: true # Speed up training and reduce VRAM usage. NVIDIA only.
-amp_bf16: false # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only. Only recommended if fp16 doesn't work.
+amp_bf16: true # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only. Only recommended if fp16 doesn't work.
 use_channels_last: true # Enable channels last memory format while using AMP. Reduces VRAM and speeds up training for most architectures, but some architectures are slower with channels last.
 fast_matmul: false # Trade precision for performance.
 num_gpu: auto
2 changes: 1 addition & 1 deletion options/train/RCAN/RCAN_finetune.yml
@@ -6,7 +6,7 @@
 name: 4x_RCAN
 scale: 4 # 1, 2, 3, 4, 8
 use_amp: true # Speed up training and reduce VRAM usage. NVIDIA only.
-amp_bf16: false # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only. Only recommended if fp16 doesn't work.
+amp_bf16: true # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only. Only recommended if fp16 doesn't work.
 use_channels_last: true # Enable channels last memory format while using AMP. Reduces VRAM and speeds up training for most architectures, but some architectures are slower with channels last.
 fast_matmul: false # Trade precision for performance.
 num_gpu: auto
2 changes: 1 addition & 1 deletion options/train/RCAN/RCAN_fromscratch.yml
@@ -6,7 +6,7 @@
 name: 4x_RCAN
 scale: 4 # 1, 2, 3, 4, 8
 use_amp: true # Speed up training and reduce VRAM usage. NVIDIA only.
-amp_bf16: false # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only. Only recommended if fp16 doesn't work.
+amp_bf16: true # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only. Only recommended if fp16 doesn't work.
 use_channels_last: true # Enable channels last memory format while using AMP. Reduces VRAM and speeds up training for most architectures, but some architectures are slower with channels last.
 fast_matmul: false # Trade precision for performance.
 num_gpu: auto
16 changes: 12 additions & 4 deletions test_scripts/test_precision.py
@@ -52,12 +52,12 @@ def format_extra_params(extra_arch_params: dict[str, Any]) -> str:
 def compare_precision(
     net: nn.Module, input_tensor: Tensor, criterion: nn.Module
 ) -> tuple[float, float]:
-    with torch.no_grad():
+    with torch.inference_mode():
         fp32_output = net(input_tensor)
 
     fp16_loss = None
     try:
-        with autocast(dtype=torch.float16, device_type="cuda"):
+        with autocast(dtype=torch.float16, device_type="cuda"), torch.inference_mode():
             fp16_output = net(input_tensor)
             fp16_loss = criterion(fp16_output.float(), fp32_output).item()
     except Exception as e:
@@ -66,7 +66,7 @@ def compare_precision(
 
     bf16_loss = None
     try:
-        with autocast(dtype=torch.bfloat16, device_type="cuda"):
+        with autocast(dtype=torch.bfloat16, device_type="cuda"), torch.inference_mode():
             bf16_output = net(input_tensor)
             bf16_loss = criterion(bf16_output.float(), fp32_output).item()
     except Exception as e:
@@ -82,7 +82,15 @@
     label = f"{name} {format_extra_params(extra_arch_params)} {scale}x"
 
     try:
-        if "realplksr" not in name:
+        if name not in {
+            "rcan",
+            "esrgan",
+            "compact",
+            "span",
+            "dat_2",
+            "spanplus",
+            "realplksr",
+        }:
             continue
 
         net: nn.Module = arch(scale=scale, **extra_arch_params).eval().to("cuda")
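On the script change itself: torch.inference_mode() is a stricter variant of torch.no_grad() that also skips view and version-counter tracking, which is safe here because outputs are only compared, never backpropagated. A self-contained sketch of the same comparison pattern, with a toy conv net standing in for a registered arch (all names below are illustrative):

import torch
from torch import nn

net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 3, 3, padding=1)
).eval().cuda()
x = torch.rand(1, 3, 64, 64, device="cuda")
criterion = nn.L1Loss()

with torch.inference_mode():
    fp32_ref = net(x)  # full-precision reference output

for dtype in (torch.float16, torch.bfloat16):
    with torch.autocast(device_type="cuda", dtype=dtype), torch.inference_mode():
        out = net(x)
        # Lower loss means the reduced-precision output tracks fp32 more closely.
        loss = criterion(out.float(), fp32_ref).item()
    print(dtype, loss)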
2 changes: 2 additions & 0 deletions traiNNer/archs/arch_info.py
@@ -26,6 +26,8 @@
     "hit_sir",
     "hit_sng",
     "hit_srf",
+    "rcan",
+    "moesr2",
     "rgt_s",
     "rgt",
     "seemore_t",
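Adding rcan and moesr2 to this arch_info list is what makes the YAML templates above default to bf16. A hedged sketch of how such a lookup can drive the default; ARCHS_PREFER_BF16 is an assumed name standing in for the list edited above, not the repo's actual identifier:

ARCHS_PREFER_BF16 = {
    "hit_sir", "hit_sng", "hit_srf", "rcan", "moesr2", "rgt_s", "rgt", "seemore_t",
}

def default_amp_bf16(arch_name: str) -> bool:
    # Archs known or suspected to misbehave in fp16 get bf16 by default.
    return arch_name.lower() in ARCHS_PREFER_BF16

assert default_amp_bf16("rcan") and default_amp_bf16("moesr2")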
19 changes: 19 additions & 0 deletions traiNNer/archs/artcnn_arch.py
@@ -105,3 +105,22 @@ def artcnn_r8f64(
         kernel_size=kernel_size,
         act=act,
     )
+
+
+@ARCH_REGISTRY.register()
+def artcnn_r8f48(
+    in_ch: int = 3,
+    scale: int = 4,
+    filters: int = 48,
+    n_block: int = 8,
+    kernel_size: int = 3,
+    act: type[nn.Module] = nn.ReLU,
+) -> ArtCNN:
+    return ArtCNN(
+        scale=scale,
+        in_ch=in_ch,
+        n_block=n_block,
+        filters=filters,
+        kernel_size=kernel_size,
+        act=act,
+    )
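A quick usage sketch for the newly registered preset; it assumes a CUDA device, the module path matches the file edited in this diff, and the 2x scale is just an example:

import torch
from traiNNer.archs.artcnn_arch import artcnn_r8f48

net = artcnn_r8f48(scale=2).eval().cuda()
with torch.inference_mode():
    sr = net(torch.rand(1, 3, 64, 64, device="cuda"))
print(sr.shape)  # expected torch.Size([1, 3, 128, 128]) for a 2x model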
