diff --git a/config.py b/config.py index 93d2ad8..7cce5bb 100644 --- a/config.py +++ b/config.py @@ -41,6 +41,5 @@ TASK_LR=1e-2 # Loading tfrecord and saving paths -# TFRECORD_PATH='train_MZSR.tfrecord' -TFRECORD_PATH='../train_MLZSSR_293848.tfrecord' +TFRECORD_PATH='train_SR_MZSR.tfrecord' CHECKPOINT_DIR='SR' diff --git a/dataGenerator.py b/dataGenerator.py index 0ec4e8c..ed22e91 100644 --- a/dataGenerator.py +++ b/dataGenerator.py @@ -51,11 +51,11 @@ def make_data_tensor(self, sess, scale_list, noise_std=0.0): '''Load TFRECORD''' def _parse_function(self, example_proto): - keys_to_features = {'image': tf.FixedLenFeature([], tf.string)} + keys_to_features = {'label': tf.FixedLenFeature([], tf.string)} parsed_features = tf.parse_single_example(example_proto, keys_to_features) - img = parsed_features['image'] + img = parsed_features['label'] img = tf.divide(tf.cast(tf.decode_raw(img, tf.uint8), tf.float32), 255.) img = tf.reshape(img, [self.HEIGHT, self.WIDTH, self.CHANNEL])