Skip to content

Commit

Permalink
Improved VGG loss functions, Improved SRGAN, and other changes and im…
Browse files Browse the repository at this point in the history
…provements
  • Loading branch information
voldien committed Apr 9, 2024
1 parent 5855632 commit 20cbc88
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 150 deletions.
24 changes: 8 additions & 16 deletions superresolution/SuperResolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,8 @@
from logging import Logger
from random import randrange
from typing import Dict

from util.loss import SSIMError, VGG16Error, psnr_loss
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow_io as tfio
from matplotlib import pyplot as plt
from tensorflow.python.data import Dataset

Expand All @@ -28,7 +25,7 @@
import models.SuperResolutionResNet
import models.SuperResolutionVDSR
import models.SuperResolutionCNN
import models.SuperResolutionGAN
import models.SuperResolutionSRGAN

from core.common import ParseDefaultArgument, DefaultArgumentParser, setup_tensorflow_strategy
from util.dataProcessing import load_dataset_from_directory, \
Expand All @@ -37,6 +34,7 @@
from util.trainingcallback import GraphHistory, SaveExampleResultImageCallBack, \
CompositeImageResultCallBack
from util.util import plotTrainingHistory
from util.loss import SSIMError, VGG16Error, VGG19Error, psnr_loss

global sr_logger
sr_logger: Logger = logging.getLogger("Super Resolution Training")
Expand Down Expand Up @@ -105,7 +103,7 @@ def load_builtin_model_interfaces() -> Dict[str, ModelBase]:
builtin_models['dcsr-resnet'] = models.SuperResolutionResNet.get_model_interface()
builtin_models['vdsr'] = models.SuperResolutionVDSR.get_model_interface()
builtin_models['cnnsr'] = models.SuperResolutionCNN.get_model_interface()
builtin_models['gan'] = models.SuperResolutionGAN.get_model_interface()
builtin_models['srgan'] = models.SuperResolutionSRGAN.get_model_interface()

return builtin_models

Expand Down Expand Up @@ -146,7 +144,7 @@ def setup_model(args: dict, builtin_models: Dict[str, ModelBase], image_input_si
def setup_loss_builtin_function(args: dict):
#
builtin_loss_functions = {'mse': tf.keras.losses.MeanSquaredError(), 'ssim': SSIMError(color_space=args.color_space),
'msa': tf.keras.losses.MeanAbsoluteError(), 'psnr': psnr_loss, 'vgg16': VGG16Error()}
'msa': tf.keras.losses.MeanAbsoluteError(), 'psnr': psnr_loss, 'vgg16': VGG16Error(), 'vgg16': VGG19Error()}

return builtin_loss_functions[args.loss_fn]

Expand Down Expand Up @@ -264,7 +262,7 @@ def run_train_model(args: dict, training_dataset: Dataset, validation_dataset: D
loss_fn = setup_loss_builtin_function(args)

# TODO metric list.
metrics = ['accuracy', ]
metrics = [tf.keras.metrics.Accuracy(), ]
if args.show_psnr:
metrics.append(PSNRMetric())

Expand Down Expand Up @@ -310,12 +308,6 @@ def run_train_model(args: dict, training_dataset: Dataset, validation_dataset: D
nth_batch_sample=args.example_nth_batch, grid_size=args.example_nth_batch_grid_size)
training_callbacks.append(example_result_call_back)

# Debug output of the trained augmented data.
# training_callbacks.append(SaveExampleResultImageCallBack(
# args.output_dir,
# training_dataset, args.color_space, fileprefix="trainSuperResolution",
# nth_batch_sample=args.example_nth_batch, grid_size=args.example_nth_batch_grid_size))

composite_train_callback = CompositeImageResultCallBack(
dir_path=args.output_dir,
name="train",
Expand Down Expand Up @@ -404,7 +396,7 @@ def dcsuperresolution_program(vargs=None):

#
parser.add_argument('--show-psnr', dest='show_psnr', action='store_true',
default=False, help='Set the grid size of number of example images.')
default=False, help='Set the grid size of number of example images.')

# TODO add support
parser.add_argument('--metrics', dest='metrics',
Expand All @@ -423,13 +415,13 @@ def dcsuperresolution_program(vargs=None):
parser.add_argument('--model', dest='model',
default='dcsr',
choices=['dcsr', 'dscr-post', 'dscr-pre', 'edsr', 'dcsr-ae', 'dcsr-resnet',
'vdsr', 'gan'],
'vdsr', 'srgan'],
help='Set which model type to use.', type=str)
#
parser.add_argument('--loss-fn', dest='loss_fn',
default='mse',
choices=['mse', 'ssim', 'msa',
'psnr', 'vgg16', 'none'],
'psnr', 'vgg16', 'vgg19', 'none'],
help='Set Loss Function', type=str)

# If invalid number of arguments, print help.
Expand Down
4 changes: 4 additions & 0 deletions superresolution/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ def DefaultArgumentParser() -> argparse.ArgumentParser:

parser.add_argument('--disable-validation', default=True, dest='use_validation', action='store_false',
help='Select if use data validation step.')

parser.add_argument('--config', default=None, dest='config',
help='Config File - Json.')

return parser


Expand Down
Loading

0 comments on commit 20cbc88

Please sign in to comment.