Skip to content

Commit

Permalink
Added upscale for 2,3,4 times and other minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
voldien committed Mar 10, 2024
1 parent 1c05032 commit ecb0cda
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 7 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,11 @@ docker run --network=host --gpus all --name sr-cuda super-resolution-cuda
```
# Convert
```
python3 superresolution/generate_tflite.py --model super-resolution-model.keras --output model-lite.tflite
```
## License
This project is licensed under the GPL+3 License - see the [LICENSE](LICENSE) file for details.
42 changes: 42 additions & 0 deletions superresolution/generate_tflite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# !/usr/bin/env python3
import argparse
import logging
import sys
import tensorflow as tf
from util.convert_model import convert_model


def convert_model_program(argv=None):
parser = argparse.ArgumentParser(
description='Tensor Lite Converter')

parser.add_argument('--model-file', dest='model_filepath',
type=str,
required=True,
help='Define filepath to the model')

parser.add_argument('--output', dest='save_path', type=str, default=None, help='')

parser.add_argument('--verbosity', type=int, dest='accumulate',
default=1,
help='Define the save/load model path')

args = parser.parse_args(args=argv)

logger = logging.getLogger('model converter')
logger.setLevel(logging.INFO)

with tf.device('/device:CPU:0'):
generate_model_path = args.model_filepath

logger.info("Loading Model: {0}".format(generate_model_path))
model = tf.keras.models.load_model(generate_model_path, compile=False)

converted_model = convert_model(model=model, dataset=None)
with open(args.save_path, "wb") as f:
f.write(converted_model)


# If running the script as main executable
if __name__ == '__main__':
convert_model_program(sys.argv[1:])
18 changes: 12 additions & 6 deletions superresolution/models/SuperResolutionEDSR.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

class EDSRSuperResolutionModel(ModelBase):
def __init__(self):
self.possible_upscale = [2, 3, 4]

self.parser = argparse.ArgumentParser(add_help=False)

#
Expand All @@ -20,13 +22,10 @@ def __init__(self):
#
self.parser.add_argument('--upscale-mode', dest='upscale_mode',
type=int,
choices=[2, 4],
choices=self.possible_upscale,
default=2,
required=False,
help='Upscale Mode')
#
self.parser.add_argument('--use-resnet', type=bool, default=False, dest='use_resnet',
help='Set the number of passes that the training set will be trained against.')

#
self.parser.add_argument('--edsr-filters', type=int, default=256, dest='edsr_filters',
Expand All @@ -37,10 +36,16 @@ def load_argument(self) -> argparse.ArgumentParser:
return self.parser

def create_model(self, input_shape, output_shape, **kwargs) -> keras.Model:
scale_width: int = int(output_shape[0] / input_shape[0])
scale_height: int = int(output_shape[1] / input_shape[1])

if scale_width not in self.possible_upscale and scale_height not in self.possible_upscale:
raise ValueError("Ivalid upscale")

# Model constructor parameters.
regularization: float = kwargs.get("regularization", 0.00001) #
upscale_mode: int = kwargs.get("upscale_mode", 2) #
num_input_filters: int = kwargs.get("edsr_filters", 192) #
upscale_mode: int = scale_width #
num_input_filters: int = kwargs.get("edsr_filters", 256) #

#
return create_edsr_model(input_shape=input_shape,
Expand Down Expand Up @@ -68,6 +73,7 @@ def create_edsr_model(input_shape: tuple, output_shape: tuple, scale: int, num_f
x = _res_block = layers.Conv2D(filters=num_filters, kernel_size=(3, 3), padding='same',
kernel_initializer=tf.keras.initializers.GlorotUniform())(
x_in)
# Residual blocks.
for _ in range(num_res_blocks):
_res_block = res_block(_res_block, num_filters, res_block_scaling)

Expand Down
3 changes: 2 additions & 1 deletion superresolution/util/convert_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import tensorflow as tf


def convert_model(model, dataset=None):

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.inference_input_type = tf.float32
Expand All @@ -14,6 +14,7 @@ def convert_model(model, dataset=None):
]

converter.post_training_quantize = True

if dataset:
converter.representative_dataset = tf.lite.RepresentativeDataset(
dataset)
Expand Down

0 comments on commit ecb0cda

Please sign in to comment.