Skip to content

Commit

Permalink
Added 4x upscale on basic super resolution model and other minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
voldien committed Apr 5, 2024
1 parent aab6e75 commit b793a7e
Show file tree
Hide file tree
Showing 10 changed files with 46 additions and 39 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,10 @@ optional arguments:
## Installation Instructions
### Setup Virtual Environment
python3.9 or higher
```bash
python3 -m venv venv
source venv/bin/activate
```
## Installing Required Packages
Expand Down
8 changes: 4 additions & 4 deletions superresolution/SuperResolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def run_train_model(args: dict, training_dataset: Dataset, validation_dataset: D

# TODO: improve
if os.path.exists(checkpoint_root_path):
custom_objects = {'PSNRMetric' : PSNRMetric(), 'VGG16Error' : VGG16Error()}
custom_objects = {'PSNRMetric': PSNRMetric(), 'VGG16Error': VGG16Error()}
training_model = tf.keras.models.load_model(checkpoint_root_path, custom_objects=custom_objects)

# Create a callback that saves the model weights
Expand All @@ -320,7 +320,7 @@ def run_train_model(args: dict, training_dataset: Dataset, validation_dataset: D
training_callbacks.append(example_result_call_back)

# Debug output of the trained augmented data.
#training_callbacks.append(SaveExampleResultImageCallBack(
# training_callbacks.append(SaveExampleResultImageCallBack(
# args.output_dir,
# training_dataset, args.color_space, fileprefix="trainSuperResolution",
# nth_batch_sample=args.example_nth_batch, grid_size=args.example_nth_batch_grid_size))
Expand Down Expand Up @@ -429,7 +429,7 @@ def dcsuperresolution_program(vargs=None):
#
parser.add_argument('--model', dest='model',
default='dcsr',
choices=['cnnsr', 'dcsr', 'dscr-post', 'dscr-pre', 'edsr', 'dcsr-ae', 'dcsr-resnet',
choices=['dcsr', 'dscr-post', 'dscr-pre', 'edsr', 'dcsr-ae', 'dcsr-resnet',
'vdsr'],
help='Set which model type to use.', type=str)
#
Expand Down Expand Up @@ -482,7 +482,7 @@ def dcsuperresolution_program(vargs=None):
args.model_filepath = os.path.join(args.output_dir, args.model_filepath)

# Allow override to enable cropping for increase details in the dataset.
override_size: tuple = (512, 512) # TODO fix.
override_size: tuple = (768, 768) # TODO fix.

# Setup Dataset
training_dataset = None
Expand Down
9 changes: 8 additions & 1 deletion superresolution/models/DCSuperResolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

class DCSuperResolutionModel(ModelBase):
def __init__(self):
self.possible_upscale = [2, 4]
self.parser = argparse.ArgumentParser(add_help=False, prog="Basic SuperResolution",
description="Basic Deep Convolutional Super Resolution")
group = self.parser.add_argument_group(self.get_name())
Expand All @@ -31,9 +32,15 @@ def load_argument(self) -> argparse.ArgumentParser:
return self.parser

def create_model(self, input_shape, output_shape, **kwargs) -> keras.Model:
scale_factor: int = int(output_shape[0] / input_shape[0])
scale_factor: int = int(output_shape[1] / input_shape[1])

if scale_factor not in self.possible_upscale and scale_factor not in self.possible_upscale:
raise ValueError("Invalid upscale")

# Model Construct Parameters.
regularization: float = kwargs.get("regularization", 0.000001) #
upscale_mode: int = kwargs.get("upscale_mode", 2) #
upscale_mode: int = scale_factor
nr_filters: int = kwargs.get("filters", 64)

#
Expand Down
35 changes: 21 additions & 14 deletions superresolution/models/SuperResolutionCNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@

class SuperResolutionModelCNN(ModelBase):
def __init__(self):
self.parser = argparse.ArgumentParser(add_help=False) # , prog="Basic SuperResolution",
# description="Basic Deep Convolutional Super Resolution")
#
self.possible_upscale = [2, 4]
self.parser = argparse.ArgumentParser(add_help=False)

self.parser.add_argument('--regularization', dest='regularization',
type=float,
default=0.000001,
Expand All @@ -30,13 +30,20 @@ def load_argument(self) -> argparse.ArgumentParser:
return self.parser

def create_model(self, input_shape, output_shape, **kwargs) -> keras.Model:
scale_factor: int = int(output_shape[0] / input_shape[0])
scale_factor: int = int(output_shape[1] / input_shape[1])

if scale_factor not in self.possible_upscale and scale_factor not in self.possible_upscale:
raise ValueError("Invalid upscale")

# Model Construct Parameters.
regularization: float = kwargs.get("regularization", 0.000001) #
upscale_mode: int = kwargs.get("upscale_mode", 2) #
upscale_mode: int = scale_factor #
num_input_filters: int = kwargs.get("input_filters", 64) #

#
return create_cnn_model(input_shape=input_shape,
output_shape=output_shape, input_filter_size=64, regularization=regularization,
output_shape=output_shape, input_filter_size=num_input_filters, regularization=regularization,
upscale_mode=upscale_mode,
kernel_activation='relu')

Expand All @@ -51,6 +58,7 @@ def get_model_interface() -> ModelBase:
def create_cnn_model(input_shape: tuple, output_shape: tuple, input_filter_size: int, regularization: float,
upscale_mode: int,
kernel_activation: str):

use_batch_norm: bool = True
use_bias: bool = True
num_conv_block: int = 3
Expand All @@ -63,19 +71,18 @@ def create_cnn_model(input_shape: tuple, output_shape: tuple, input_filter_size:
x = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(x)

# Convolutional block
for i in range(0, int(upscale_mode / 2)):
for _ in range(0, num_conv_block):
filter_size = input_filter_size << i
x = layers.Conv2D(filters=filter_size, kernel_size=(3, 3), strides=1, padding='same', use_bias=use_bias,
kernel_initializer=tf.keras.initializers.HeNormal())(x)
if use_batch_norm:
x = layers.BatchNormalization(dtype='float32')(x)
x = create_activation(kernel_activation)(x)
for _ in range(0, num_conv_block):
filter_size = input_filter_size << i
x = layers.Conv2D(filters=filter_size, kernel_size=(3, 3), strides=1, padding='same', use_bias=use_bias,
kernel_initializer=tf.keras.initializers.HeNormal())(x)
if use_batch_norm:
x = layers.BatchNormalization(dtype='float32')(x)
x = create_activation(kernel_activation)(x)

# Output to 3 channel output.
x = layers.Conv2DTranspose(filters=output_channels, kernel_size=(9, 9), strides=(
1, 1), padding='same', use_bias=use_bias, kernel_initializer=tf.keras.initializers.HeNormal(),
bias_initializer=tf.keras.initializers.HeNormal())(x)
bias_initializer=tf.keras.initializers.HeNormal())(x)
x = layers.Activation('tanh')(x)
x = layers.ActivityRegularization(l1=regularization, l2=0)(x)

Expand Down
6 changes: 0 additions & 6 deletions superresolution/models/SuperResolutionResNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,6 @@ def __init__(self):
default=0.001,
help='Set the L1 Regularization applied.')

self.parser.add_argument('--upscale-mode', dest='upscale_mode',
type=str,
choices=[''],
default='',
help='Set the L1 Regularization applied.')

def load_argument(self) -> argparse.ArgumentParser:
"""Load in the file for extracting text."""

Expand Down
4 changes: 0 additions & 4 deletions superresolution/models/SuperResolutionVDSR.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ class VDSRSuperResolutionModel(ModelBase):
def __init__(self):
self.possible_upscale = [2, 4]


self.parser = argparse.ArgumentParser(add_help=False)
#
self.parser.add_argument('--regularization', dest='regularization',
Expand All @@ -35,9 +34,6 @@ def create_model(self, input_shape, output_shape, **kwargs) -> keras.Model:

if scale_factor not in self.possible_upscale and scale_factor not in self.possible_upscale:
raise ValueError("Invalid upscale")

# parser_result = self.parser.parse_known_args(sys.argv[1:])
# Model constructor parameters.

regularization: float = kwargs.get("regularization", 0.00001) #
upscale_mode: int = scale_factor #
Expand Down
2 changes: 1 addition & 1 deletion superresolution/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from tensorflow.keras import layers


def create_activation(activation):
def create_activation(activation: str):
if activation == "leaky_relu":
return layers.LeakyReLU(alpha=0.2, dtype='float32')
elif activation == "relu":
Expand Down
4 changes: 2 additions & 2 deletions superresolution/util/convert_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import tensorflow as tf

def convert_model(model, dataset=None):

def convert_model(model, dataset=None):
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.inference_input_type = tf.float32
Expand All @@ -14,7 +14,7 @@ def convert_model(model, dataset=None):
]

converter.post_training_quantize = True

if dataset:
converter.representative_dataset = tf.lite.RepresentativeDataset(
dataset)
Expand Down
4 changes: 2 additions & 2 deletions superresolution/util/dataProcessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def DownScaleLayer(data):
interpolation='bilinear',
crop_to_aspect_ratio=False
)])

expectedScale = tf.keras.Sequential([
layers.Resizing(
output_size[0],
Expand Down Expand Up @@ -181,7 +181,7 @@ def resize_data(images):

if crop:
dataset = dataset.map(resize_data)

DownScaledDataSet = (
dataset
.map(DownScaleLayer,
Expand Down
10 changes: 6 additions & 4 deletions superresolution/util/trainingcallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ def compute_normalized_PSNR(orignal, data):

class SaveExampleResultImageCallBack(tf.keras.callbacks.Callback):

def __init__(self, dir_path, train_data_subset, color_space: str, nth_batch_sample: int = 512, grid_size: int = 6, fileprefix: str = "SuperResolution",
def __init__(self, dir_path, train_data_subset, color_space: str, nth_batch_sample: int = 512, grid_size: int = 6,
fileprefix: str = "SuperResolution",
**kwargs):
super(tf.keras.callbacks.Callback, self).__init__(**kwargs)

Expand All @@ -37,15 +38,16 @@ def on_epoch_begin(self, epoch, logs=None):
def on_epoch_end(self, epoch, logs=None):
fig = show_expect_predicted_result(model=self.model, image_batch_dataset=self.trainSet,
color_space=self.color_space, nr_col=self.grid_size)
fig.savefig(os.path.join(self.dir_path, "{0}{1}.png".format(self.fileprefix,epoch)))
fig.savefig(os.path.join(self.dir_path, "{0}{1}.png".format(self.fileprefix, epoch)))
fig.clf()
plt.close(fig)

def on_train_batch_end(self, batch, logs=None):
if batch % self.nth_batch_sample == 0:
fig = show_expect_predicted_result(model=self.model, image_batch_dataset=self.trainSet,
color_space=self.color_space, nr_col=self.grid_size)
fig.savefig(os.path.join(self.dir_path, "{0}_{1}_{2}.png".format(self.fileprefix, self.current_epoch, batch)))
fig.savefig(
os.path.join(self.dir_path, "{0}_{1}_{2}.png".format(self.fileprefix, self.current_epoch, batch)))
fig.clf()
plt.close(fig)

Expand Down Expand Up @@ -154,7 +156,7 @@ def on_train_batch_end(self, batch, logs=None):
def on_epoch_end(self, epoch, logs=None):
super().on_epoch_end(epoch=epoch, logs=logs)

#TODO: add file output.
# TODO: add file output.

# Plot detailed
fig = plotTrainingHistory(self.batch_history, x_label="Batches", y_label="value")
Expand Down

0 comments on commit b793a7e

Please sign in to comment.