Started adding SRGAN
voldien committed Apr 6, 2024
1 parent b793a7e commit 5855632
Showing 9 changed files with 642 additions and 123 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -41,6 +41,17 @@ python superresolution/SuperResolution.py --batch-size 16 --epochs 10 --image-s
![Gangsta Anime EDSR Super Resolution Example from Trained model](https://github.com/voldien/SuperResolution/assets/9608088/24cccb38-807f-4454-bbc6-35ad9e03b57f)
![Amagi Brilliant Park Anime EDSR Super Resolution Example from Trained model](https://github.com/voldien/SuperResolution/assets/9608088/153792f5-c35a-4fae-8bba-aed47c8902de)


### GAN - Generative Adversarial Network - Super Resolution

```bash
# Hypothetical example: 'gan' and 'vgg16' are among the model/loss choices added in this commit.
python superresolution/SuperResolution.py --model gan --loss-fn vgg16 --batch-size 16 --epochs 10
```

Following SRGAN, the generator minimizes a perceptual loss: the content loss $l_X^{SR}$ plus an adversarial term weighted by $10^{-3}$:

$$
l^{SR} = l_X^{SR} + 10^{-3}\, l^{SR}_{Gen}
$$
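
A minimal TensorFlow sketch of that combined objective (illustrative only; the actual adversarial wiring in this commit lives in `models/SuperResolutionGAN.py`, and the `fake_logits` discriminator output here is an assumption):

```python
import tensorflow as tf

def srgan_perceptual_loss(content_loss: tf.Tensor, fake_logits: tf.Tensor) -> tf.Tensor:
    """l^SR = l_X^SR (content) + 1e-3 * l_Gen^SR (adversarial)."""
    # Generator adversarial term: cross-entropy of the discriminator's logits
    # on generated images against "real" labels, i.e. -log D(G(I_LR)).
    adversarial = tf.keras.losses.binary_crossentropy(
        tf.ones_like(fake_logits), fake_logits, from_logits=True)
    return content_loss + 1e-3 * tf.reduce_mean(adversarial)
```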

### AE - AutoEncoder Super Resolution

150 changes: 76 additions & 74 deletions superresolution/SuperResolution.py
@@ -12,6 +12,7 @@
from random import randrange
from typing import Dict

from util.loss import SSIMError, VGG16Error, psnr_loss
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow_io as tfio
@@ -27,12 +28,13 @@
import models.SuperResolutionResNet
import models.SuperResolutionVDSR
import models.SuperResolutionCNN
import models.SuperResolutionGAN

from core.common import ParseDefaultArgument, DefaultArgumentParser, setup_tensorflow_strategy
from util.dataProcessing import load_dataset_from_directory, \
configure_dataset_performance, dataset_super_resolution, augment_dataset
from util.metrics import PSNRMetric, VGG16Error
from util.trainingcallback import GraphHistory, SaveExampleResultImageCallBack, compute_normalized_PSNR, \
from util.metrics import PSNRMetric
from util.trainingcallback import GraphHistory, SaveExampleResultImageCallBack, \
CompositeImageResultCallBack
from util.util import plotTrainingHistory

@@ -83,7 +85,8 @@ def load_dataset_collection(filepaths: list, args: dict, override_size: tuple) -
data_dir = pathlib.Path(directory_path)
logging.info("Loading dataset directory {0}".format(data_dir))

local_dataset = load_dataset_from_directory(data_path=data_dir, args=args, override_size=override_size)
local_dataset = load_dataset_from_directory(
data_path=data_dir, args=args, override_size=override_size)
if not training_dataset:
training_dataset = local_dataset
else:
@@ -102,6 +105,7 @@ def load_builtin_model_interfaces() -> Dict[str, ModelBase]:
builtin_models['dcsr-resnet'] = models.SuperResolutionResNet.get_model_interface()
builtin_models['vdsr'] = models.SuperResolutionVDSR.get_model_interface()
builtin_models['cnnsr'] = models.SuperResolutionCNN.get_model_interface()
builtin_models['gan'] = models.SuperResolutionGAN.get_model_interface()

return builtin_models

@@ -140,34 +144,8 @@ def setup_model(args: dict, builtin_models: Dict[str, ModelBase], image_input_si


def setup_loss_builtin_function(args: dict):
def ssim_loss(y_true, y_pred):
# TODO convert color space.
y_true_color = None
y_pred_color = None

#
if args.color_space == 'rgb':
# Remap [-1,1] to [0,1]
y_true_color = ((y_true + 1.0) * 0.5)
y_pred_color = ((y_pred + 1.0) * 0.5)
elif args.color_space == 'lab':
# Remap [-1,1] -> [-128, 128] -> [0,1]
y_true_color = tfio.experimental.color.lab_to_rgb(y_true * 128)
y_pred_color = tfio.experimental.color.lab_to_rgb(y_pred * 128)
else:
assert 0

return (1 - tf.reduce_mean(tf.image.ssim(y_true_color, y_pred_color, max_val=1.0, filter_size=11,
filter_sigma=1.5, k1=0.01, k2=0.03)))

def psnr_loss(y_true, y_pred): # TODO: fix equation.
return 20.0 - compute_normalized_PSNR(y_true, y_pred)

def total_variation_loss(y_true, y_pred): # TODO: fix equation.
return 1.0 - tf.reduce_sum(tf.image.total_variation(y_true, y_pred))

#
builtin_loss_functions = {'mse': tf.keras.losses.MeanSquaredError(), 'ssim': ssim_loss,
builtin_loss_functions = {'mse': tf.keras.losses.MeanSquaredError(), 'ssim': SSIMError(color_space=args.color_space),
'msa': tf.keras.losses.MeanAbsoluteError(), 'psnr': psnr_loss, 'vgg16': VGG16Error()}

return builtin_loss_functions[args.loss_fn]
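
The inline `ssim_loss` and `psnr_loss` helpers removed above now come from `util.loss` (imported at the top of the file). That module is not shown in this commit view; a plausible sketch of the two symbols, reusing the deleted color-space handling (assumptions: the real `SSIMError` mirrors the old inline logic, and `psnr_loss` builds on `tf.image.psnr`):

```python
import tensorflow as tf
import tensorflow_io as tfio

class SSIMError(tf.keras.losses.Loss):
    """1 - mean SSIM, after remapping predictions back to [0, 1] RGB."""

    def __init__(self, color_space: str = 'rgb', **kwargs):
        super().__init__(**kwargs)
        self.color_space = color_space

    def call(self, y_true, y_pred):
        if self.color_space == 'rgb':
            # Remap [-1, 1] -> [0, 1].
            y_true = (y_true + 1.0) * 0.5
            y_pred = (y_pred + 1.0) * 0.5
        elif self.color_space == 'lab':
            # Remap [-1, 1] -> [-128, 128], then convert to RGB.
            y_true = tfio.experimental.color.lab_to_rgb(y_true * 128)
            y_pred = tfio.experimental.color.lab_to_rgb(y_pred * 128)
        else:
            raise ValueError('unsupported color space: ' + self.color_space)
        ssim = tf.image.ssim(y_true, y_pred, max_val=1.0, filter_size=11,
                             filter_sigma=1.5, k1=0.01, k2=0.03)
        return 1.0 - tf.reduce_mean(ssim)

def psnr_loss(y_true, y_pred):
    """Decreases as PSNR (in dB) increases; the constant offset only shifts the scale."""
    return 40.0 - tf.reduce_mean(tf.image.psnr(y_true, y_pred, max_val=1.0))
```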
@@ -177,7 +155,8 @@ def run_train_model(args: dict, training_dataset: Dataset, validation_dataset: D
test_dataset: Dataset = None):
# Configure how models will be executed.
strategy = setup_tensorflow_strategy(args=args)
sr_logger.info('Number of devices: {0}'.format(strategy.num_replicas_in_sync))
sr_logger.info('Number of devices: {0}'.format(
strategy.num_replicas_in_sync))

# Compute the total batch size.
batch_size: int = args.batch_size * strategy.num_replicas_in_sync
@@ -203,11 +182,13 @@ def run_train_model(args: dict, training_dataset: Dataset, validation_dataset: D
non_augmented_dataset_validation = dataset_super_resolution(dataset=validation_dataset,
input_size=image_input_size,
output_size=image_output_size)
non_augmented_dataset_validation = non_augmented_dataset_validation.batch(batch_size)
non_augmented_dataset_validation = non_augmented_dataset_validation.batch(
batch_size)

options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
non_augmented_dataset_validation = non_augmented_dataset_validation.with_options(options)
non_augmented_dataset_validation = non_augmented_dataset_validation.with_options(
options)

non_augmented_dataset_train = configure_dataset_performance(ds=non_augmented_dataset_train, use_cache=False,
cache_path=None, shuffle_size=0)
Expand All @@ -216,15 +197,17 @@ def run_train_model(args: dict, training_dataset: Dataset, validation_dataset: D

options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
non_augmented_dataset_train = non_augmented_dataset_train.with_options(options)
non_augmented_dataset_train = non_augmented_dataset_train.with_options(
options)

# Configure cache, shuffle and performance of the dataset.
training_dataset = configure_dataset_performance(ds=training_dataset, use_cache=args.cache_ram,
cache_path=args.cache_path,
shuffle_size=args.dataset_shuffle_size)

# Apply data augmentation
training_dataset = augment_dataset(dataset=training_dataset, image_crop_shape=image_output_size)
training_dataset = augment_dataset(
dataset=training_dataset, image_crop_shape=image_output_size)

# Transform data to fit upscale.
training_dataset = dataset_super_resolution(dataset=training_dataset,
@@ -247,7 +230,8 @@ def run_train_model(args: dict, training_dataset: Dataset, validation_dataset: D
cache_path=None,
shuffle_size=0)
# Apply data augmentation
validation_data_ds = augment_dataset(dataset=validation_data_ds, image_crop_shape=image_output_size)
validation_data_ds = augment_dataset(
dataset=validation_data_ds, image_crop_shape=image_output_size)

# Transform data to fit upscale.
validation_data_ds = dataset_super_resolution(dataset=validation_data_ds,
@@ -274,14 +258,6 @@ def run_train_model(args: dict, training_dataset: Dataset, validation_dataset: D
training_model = setup_model(args=args, builtin_models=builtin_models, image_input_size=image_input_size,
image_output_size=image_output_size)

# Save the model as an image to directory, for easy backtracking of the model composition.
tf.keras.utils.plot_model(
training_model, to_file=os.path.join(args.output_dir, 'Model.png'),
show_shapes=True, show_dtype=True,
show_layer_names=True, rankdir='TB', expand_nested=False, dpi=96,
layer_range=None
)

sr_logger.debug(training_model.summary())

# NOTE currently, only support checkpoint if generated model and not when using existing.
@@ -292,26 +268,41 @@ def run_train_model(args: dict, training_dataset: Dataset, validation_dataset: D
if args.show_psnr:
metrics.append(PSNRMetric())

training_model.compile(optimizer=model_optimizer, loss=loss_fn, metrics=metrics)
training_model.compile(optimizer=model_optimizer,
loss=loss_fn, metrics=metrics)

# Save the model as an image to directory, for easy backtracking of the model composition.
tf.keras.utils.plot_model(
training_model, to_file=os.path.join(args.output_dir, 'Model.png'),
show_shapes=True, show_dtype=True,
show_layer_names=True, rankdir='TB', expand_nested=False, dpi=96,
layer_range=None
)

# checkpoint root_path
checkpoint_root_path: str = args.checkpoint_dir

# TODO: improve
if os.path.exists(checkpoint_root_path):
custom_objects = {'PSNRMetric': PSNRMetric(), 'VGG16Error': VGG16Error()}
training_model = tf.keras.models.load_model(checkpoint_root_path, custom_objects=custom_objects)

# Create a callback that saves the model weights
checkpoint_path = os.path.join(checkpoint_root_path, "cpkt-{epoch:02d}")
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
monitor='val_loss' if validation_data_ds else 'loss',
mode='min',
save_freq='epoch',
verbose=1)

training_callbacks: list = [tf.keras.callbacks.TerminateOnNaN(), checkpoint_callback]
custom_objects = {
'PSNRMetric': PSNRMetric(), 'VGG16Error': VGG16Error()}
training_model = tf.keras.models.load_model(
checkpoint_root_path, custom_objects=custom_objects)

training_callbacks: list = [tf.keras.callbacks.TerminateOnNaN()]

if args.use_checkpoint:
# Create a callback that saves the model weights
checkpoint_path = os.path.join(
checkpoint_root_path, "cpkt-{epoch:02d}")
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
monitor='val_loss' if validation_data_ds else 'loss',
mode='min',
save_freq='epoch',
verbose=1)

training_callbacks.append(checkpoint_callback)

example_result_call_back = SaveExampleResultImageCallBack(
args.output_dir,
@@ -321,9 +312,9 @@ def run_train_model(args: dict, training_dataset: Dataset, validation_dataset: D

# Debug output of the trained augmented data.
# training_callbacks.append(SaveExampleResultImageCallBack(
# args.output_dir,
# training_dataset, args.color_space, fileprefix="trainSuperResolution",
# nth_batch_sample=args.example_nth_batch, grid_size=args.example_nth_batch_grid_size))

composite_train_callback = CompositeImageResultCallBack(
dir_path=args.output_dir,
@@ -339,7 +330,8 @@ def run_train_model(args: dict, training_dataset: Dataset, validation_dataset: D
train_data_subset=non_augmented_dataset_validation, color_space=args.color_space)
training_callbacks.append(composite_validation_callback)

graph_output_filepath: str = os.path.join(args.output_dir, "history_graph.png")
graph_output_filepath: str = os.path.join(
args.output_dir, "history_graph.png")
training_callbacks.append(GraphHistory(filepath=graph_output_filepath))

# Save copy.
@@ -397,7 +389,8 @@ def dcsuperresolution_program(vargs=None):
help='Define file path that the generator model will be saved at.')
#
parser.add_argument('--output-dir', type=str, dest='output_dir',
default=str.format("super-resolution-{0}", date.today().strftime("%b-%d-%Y_%H:%M:%S")),
default=str.format(
"super-resolution-{0}", date.today().strftime("%b-%d-%Y_%H:%M:%S")),
help='Set the output directory that all the models and results will be stored at')
#
parser.add_argument('--example-batch', dest='example_nth_batch', required=False, # TODO rename
@@ -411,7 +404,7 @@ def dcsuperresolution_program(vargs=None):

#
parser.add_argument('--show-psnr', dest='show_psnr', action='store_true',
                        default=False, help='Report PSNR as an additional metric during training.')

# TODO add support
parser.add_argument('--metrics', dest='metrics',
@@ -430,12 +423,13 @@ def dcsuperresolution_program(vargs=None):
parser.add_argument('--model', dest='model',
default='dcsr',
choices=['dcsr', 'dscr-post', 'dscr-pre', 'edsr', 'dcsr-ae', 'dcsr-resnet',
'vdsr'],
'vdsr', 'gan'],
help='Set which model type to use.', type=str)
#
parser.add_argument('--loss-fn', dest='loss_fn',
default='mse',
choices=['mse', 'ssim', 'msa', 'psnr', 'vgg16', 'none'],
choices=['mse', 'ssim', 'msa',
'psnr', 'vgg16', 'none'],
help='Set Loss Function', type=str)

# If invalid number of arguments, print help.
@@ -455,20 +449,24 @@ def dcsuperresolution_program(vargs=None):
# Override the default logging level
sr_logger.setLevel(args.verbosity)
# Add logging output path.
sr_logger.addHandler(logging.FileHandler(filename=os.path.join(args.output_dir, "log.txt")))
sr_logger.addHandler(logging.FileHandler(
filename=os.path.join(args.output_dir, "log.txt")))

# Logging about all the options etc.
sr_logger.info(str.format("Epochs: {0}", args.epochs))
sr_logger.info(str.format("Batch Size: {0}", args.batch_size))

sr_logger.info(str.format("Use float16: {0}", args.use_float16))

sr_logger.info(str.format("CheckPoint Save Every Nth Epoch: {0}", args.checkpoint_every_nth_epoch))
sr_logger.info(str.format(
"CheckPoint Save Every Nth Epoch: {0}", args.checkpoint_every_nth_epoch))

sr_logger.info(str.format("Use RAM Cache: {0}", args.cache_ram))

sr_logger.info(str.format("Example Batch Grid Size: {0}", args.example_nth_batch_grid_size))
sr_logger.info(str.format("Image Training Set: {0}", args.input_image_size))
sr_logger.info(str.format(
"Example Batch Grid Size: {0}", args.example_nth_batch_grid_size))
sr_logger.info(str.format(
"Image Training Set: {0}", args.input_image_size))
sr_logger.info(str.format("Learning Rate: {0}", args.learning_rate))
sr_logger.info(str.format(
"Learning Decay Rate: {0}", args.learning_rate_decay))
@@ -479,7 +477,8 @@

# Create absolute path for model file, if relative path.
if not os.path.isabs(args.model_filepath):
args.model_filepath = os.path.join(args.output_dir, args.model_filepath)
args.model_filepath = os.path.join(
args.output_dir, args.model_filepath)

# Allow override to enable cropping for increase details in the dataset.
override_size: tuple = (768, 768) # TODO fix.
@@ -500,11 +499,13 @@
test_dataset = None
test_set_filepaths = args.test_directory_paths
if test_set_filepaths:
test_dataset = load_dataset_collection(filepaths=test_set_filepaths, args=args, override_size=override_size)
test_dataset = load_dataset_collection(
filepaths=test_set_filepaths, args=args, override_size=override_size)

if not training_dataset:
sr_logger.error("Failed to construct dataset")
raise RuntimeError("Could not create dataset from {0}".format(data_set_filepaths))
raise RuntimeError(
"Could not create dataset from {0}".format(data_set_filepaths))

# Make a copy of the command line.
commandline = ' '.join(vargs)
@@ -518,7 +519,8 @@
json.dump(args.__dict__, writefile, indent=2)

# Main Train Model
run_train_model(args, training_dataset, validation_dataset, test_dataset)
run_train_model(args, training_dataset,
validation_dataset, test_dataset)

except Exception as ex:
print(ex)
5 changes: 5 additions & 0 deletions superresolution/core/common.py
@@ -18,6 +18,11 @@ def DefaultArgumentParser() -> argparse.ArgumentParser:

parser.add_argument('--batch-size', type=int, default=16, dest='batch_size',
help='number of training element per each batch, during training.')


#
parser.add_argument('--use-checkpoint', dest='use_checkpoint', action='store_true',
                        help='Enable saving and loading of model checkpoints during training.')
#
parser.add_argument('--checkpoint-filepath', type=str, dest='checkpoint_dir',
default="training_checkpoints",
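
With this commit, checkpoint writing becomes opt-in. A hypothetical invocation combining the new flag with the existing checkpoint path option (all flags taken from the argument parsers in this diff):

```bash
# Checkpoints are saved once per epoch only when --use-checkpoint is passed.
python superresolution/SuperResolution.py --model gan --use-checkpoint \
    --checkpoint-filepath training_checkpoints --batch-size 16 --epochs 10
```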
