Improved restore train session, and other minor changes
voldien committed Feb 29, 2024
1 parent ced1b6d commit 3ce2736
Showing 3 changed files with 17 additions and 10 deletions.
16 changes: 7 additions & 9 deletions superresolution/SuperResolution.py
@@ -199,7 +199,7 @@ def run_train_model(args: dict, training_dataset: Dataset, validation_dataset: D
     # Setup non-augmented version for presentation.
     non_augmented_dataset_train = dataset_super_resolution(dataset=training_dataset,
                                                            input_size=image_input_size,
-                                                           output_size=image_output_size)
+                                                           output_size=image_output_size, crop=False)
     non_augmented_dataset_validation = None
     if validation_dataset:
         non_augmented_dataset_validation = dataset_super_resolution(dataset=validation_dataset,
@@ -299,6 +299,11 @@ def run_train_model(args: dict, training_dataset: Dataset, validation_dataset: D
     # checkpoint root_path
     checkpoint_root_path: str = args.checkpoint_dir
 
+    # TODO: improve
+    if os.path.exists(checkpoint_root_path):
+        custom_objects = {'PSNRMetric': PSNRMetric(), 'VGG16Error': VGG16Error()}
+        training_model = tf.keras.models.load_model(checkpoint_root_path, custom_objects=custom_objects)
+
     # Create a callback that saves the model weights
     checkpoint_path = os.path.join(checkpoint_root_path, "cpkt-{epoch:02d}")
     checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
@@ -308,12 +313,6 @@ def run_train_model(args: dict, training_dataset: Dataset, validation_dataset: D
         save_freq='epoch',
         verbose=1)
 
-    #
-    checkpoint = tf.train.Checkpoint(optimizer=model_optimizer, model=training_model)
-    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir=checkpoint_root_path)
-    if latest_checkpoint:
-        status = checkpoint.restore(save_path=latest_checkpoint).assert_consumed()
-
     training_callbacks: list = [tf.keras.callbacks.TerminateOnNaN(), checkpoint_callback]
 
     example_result_call_back = SaveExampleResultImageCallBack(
@@ -322,7 +321,7 @@ def run_train_model(args: dict, training_dataset: Dataset, validation_dataset: D
         nth_batch_sample=args.example_nth_batch, grid_size=args.example_nth_batch_grid_size)
     training_callbacks.append(example_result_call_back)
 
-    # Debug output of trained data.
+    # Debug output of the trained augmented data.
     #training_callbacks.append(SaveExampleResultImageCallBack(
     #    args.output_dir,
     #    training_dataset, args.color_space, fileprefix="trainSuperResolution",
@@ -352,7 +351,6 @@ def run_train_model(args: dict, training_dataset: Dataset, validation_dataset: D
         epochs=args.epochs,
         callbacks=training_callbacks)
     #
-    # training_model.load_weights(checkpoint_path)
     training_model.save(args.model_filepath)
 
     # Test model.
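Note on the restore path: the commit replaces the variable-level tf.train.Checkpoint restore (whose assert_consumed() raises if the checkpoint does not match the built model) with reloading the entire saved model, including optimizer state and compiled metrics, via tf.keras.models.load_model. A minimal sketch of a restore helper along these lines follows; restore_training_session, the newest-"cpkt-*" scan (one direction the "TODO: improve" comment could take, since ModelCheckpoint writes a cpkt-XX directory per epoch), and the util.metrics import path are assumptions. The commit itself loads checkpoint_root_path directly and maps the names to metric instances rather than classes.

    import glob
    import os

    import tensorflow as tf

    from util.metrics import PSNRMetric, VGG16Error  # assumed import path

    def restore_training_session(checkpoint_root: str):
        """Reload the newest per-epoch saved model, if any; else return None."""
        # ModelCheckpoint writes one "cpkt-XX" directory per epoch; a lexical
        # sort suffices because the epoch number is zero-padded to two digits.
        candidates = sorted(glob.glob(os.path.join(checkpoint_root, "cpkt-*")))
        if not candidates:
            return None
        # load_model must be able to resolve the custom metric and loss names.
        custom_objects = {"PSNRMetric": PSNRMetric, "VGG16Error": VGG16Error}
        return tf.keras.models.load_model(candidates[-1], custom_objects=custom_objects)

Unlike the removed tf.train.Checkpoint path, which only restored variable values into an already-built model, loading the whole saved model also restores the compiled loss and metric configuration, which is what makes the serialization registrations added in metrics.py necessary.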
9 changes: 8 additions & 1 deletion superresolution/util/dataProcessing.py
@@ -128,7 +128,7 @@ def AgumentFunc(x):
     return dataset
 
 
-def dataset_super_resolution(dataset: Dataset, input_size: tuple, output_size: tuple) -> Dataset:
+def dataset_super_resolution(dataset: Dataset, input_size: tuple, output_size: tuple, crop: bool = False) -> Dataset:
     """
     Resize the super-resolution input data and expected data to the correct
     sizes, so the model is provided with correctly sized data.
@@ -176,6 +176,13 @@ def DownScaleLayer(data):
 
         return data, expected_data
 
+    def resize_data(images):
+        SIZE = output_size
+        return tf.image.resize_with_crop_or_pad(images, SIZE[0], SIZE[1])
+
+    if crop:
+        dataset = dataset.map(resize_data)
+
     DownScaledDataSet = (
         dataset
         .map(DownScaleLayer,
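The new crop flag runs every image through tf.image.resize_with_crop_or_pad before the downscale mapping: images larger than output_size are center-cropped, smaller ones are zero-padded, so the dataset ends up at a uniform size. A quick sketch of that behavior in isolation; the 160x160 toy tensors and the 128x128 target are assumed values for illustration:

    import tensorflow as tf

    output_size = (128, 128)  # assumed target size

    def resize_data(images):
        # Center-crop larger images, zero-pad smaller ones, to output_size.
        return tf.image.resize_with_crop_or_pad(images, output_size[0], output_size[1])

    # Toy dataset of 160x160 RGB images; each element comes out 128x128.
    dataset = tf.data.Dataset.from_tensor_slices(tf.zeros([4, 160, 160, 3]))
    dataset = dataset.map(resize_data)
    print(dataset.element_spec)  # TensorSpec(shape=(128, 128, 3), dtype=tf.float32, ...)

Since the parameter defaults to crop=False and the presentation dataset in SuperResolution.py passes crop=False explicitly, existing call sites keep their previous behavior.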
2 changes: 2 additions & 0 deletions superresolution/util/metrics.py
@@ -8,6 +8,7 @@
 from util.trainingcallback import compute_normalized_PSNR
 
 
+@keras.saving.register_keras_serializable(package="superresolution", name="PSNRMetric")
 class PSNRMetric(tf.keras.metrics.MeanMetricWrapper):
     def __init__(self, name="PSNR", dtype=None):
         def psnr(y_true, y_pred):
@@ -27,6 +28,7 @@ def psnr(y_true, y_pred):
         super().__init__(psnr, name, dtype=dtype)
 
 
+@keras.saving.register_keras_serializable(package="superresolution", name="VGG16Error")
 class VGG16Error(LossFunctionWrapper):
     selected_layers = ['block1_conv1', 'block2_conv2',
                        "block3_conv3", 'block4_conv3', 'block5_conv4']
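These two decorators place the custom metric and loss in Keras' global serialization registry under the "superresolution" package, so a model saved with them compiled can be deserialized by name; this complements the explicit custom_objects mapping in SuperResolution.py. A small sketch of what registration does, using assumed demo names and assuming a Keras version that exposes the keras.saving namespace, as the diff itself does:

    import tensorflow as tf
    from tensorflow import keras

    # Registered objects are stored under the key "<package>><name>".
    @keras.saving.register_keras_serializable(package="demo", name="ScaledMAE")
    def scaled_mae(y_true, y_pred):
        return 10.0 * tf.reduce_mean(tf.abs(y_true - y_pred), axis=-1)

    print(keras.saving.get_registered_name(scaled_mae))           # demo>ScaledMAE
    print(keras.saving.get_registered_object("demo>ScaledMAE"))   # <function scaled_mae ...>

Because registration happens at import time, simply importing util.metrics is enough for load_model to resolve PSNRMetric and VGG16Error by their registered names.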
