dataset.py
"""This module provides the dataset parsing function to generate the training
and testing data."""
import tensorflow as tf
from preprocessing import normalize
def build_dataset(tfrecord_file,
batch_size,
one_hot_depth,
training=False,
buffer_size=4096):
"""Generate parsed TensorFlow dataset.
Args:
tfrecord_file: the tfrecord file path.
batch_size: batch size.
one_hot_depth: the depth for one hot encoding, usually the number of
classes.
training: a boolean indicating whether the dataset will be used for
training.
buffer_size: hwo large the buffer is for shuffling the samples.
Returns:
a parsed dataset.
"""
    # Let TensorFlow tune the input pipeline automatically.
    autotune = tf.data.experimental.AUTOTUNE

    # Describe how the dataset was constructed. The author who created the
    # file is responsible for this information.
    feature_description = {'image/height': tf.io.FixedLenFeature([], tf.int64),
                           'image/width': tf.io.FixedLenFeature([], tf.int64),
                           'image/depth': tf.io.FixedLenFeature([], tf.int64),
                           'image/encoded': tf.io.FixedLenFeature([], tf.string),
                           'label': tf.io.FixedLenFeature([], tf.int64)}

    # Define a helper function to decode the tf-example. This function will be
    # called by map() later.
    def _parse_function(example):
        features = tf.io.parse_single_example(example, feature_description)
        image = tf.image.decode_jpeg(features['image/encoded'])
        image = normalize(image)
        label = tf.one_hot(features['label'], depth=one_hot_depth,
                           dtype=tf.float32)

        return image, label
    # Now construct the dataset from the tfrecord file.
    dataset = tf.data.TFRecordDataset(tfrecord_file)

    # Shuffle the samples if the dataset is used for training.
    if training:
        dataset = dataset.shuffle(buffer_size)

    # Parse the dataset to get samples.
    dataset = dataset.map(_parse_function, num_parallel_calls=autotune)

    # Batch the data.
    dataset = dataset.batch(batch_size, drop_remainder=True)

    # Prefetch the data to accelerate the pipeline.
    dataset = dataset.prefetch(autotune)

    return dataset
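

# A minimal usage sketch, not part of the original module: it only illustrates
# how build_dataset() is typically called for training. The record file name
# 'train.record', the batch size, and the class count of 1000 are hypothetical
# placeholders; substitute the values for your own dataset.
if __name__ == '__main__':
    train_dataset = build_dataset('train.record',
                                  batch_size=128,
                                  one_hot_depth=1000,
                                  training=True,
                                  buffer_size=4096)

    # Each element is a batch of normalized images and one-hot encoded labels.
    for images, labels in train_dataset.take(1):
        print(images.shape, labels.shape)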