-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdecompress.py
48 lines (33 loc) · 1.38 KB
/
decompress.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os
import tensorflow as tf
import tensorflow_compression as tfc
import argparse
from glob import glob
def load_model(args):
    """Load and return the saved Keras model located at ``args.model_path``.

    The model is loaded with ``compile=False``.

    Args:
        args: Namespace with a ``model_path`` attribute.

    Returns:
        The object returned by ``tf.keras.models.load_model``.
    """
    return tf.keras.models.load_model(args.model_path, compile=False)
def decompress(model, args):
    """Decode compressed file(s) and save each reconstruction as a PNG.

    ``args.binary_path`` may name a single file or a directory; for a
    directory, every regular file directly inside it is processed. Each file
    is read as a ``tfc.PackedTensors`` blob, unpacked with the dtypes taken
    from ``model.decompress.input_signature``, and passed to
    ``model.decompress``. The result is PNG-encoded and written to
    ``./outputs/reconstruction/<name>.png``, where ``<name>`` is the input's
    base name up to its first ``.``.

    Args:
        model: Object exposing a ``decompress`` callable that carries an
            ``input_signature`` attribute (e.g. the model returned by
            ``load_model``).
        args: Namespace with a ``binary_path`` attribute.
    """
    os.makedirs("outputs/reconstruction/", exist_ok=True)

    if os.path.isdir(args.binary_path):
        # Skip sub-directories (open() would fail on them) and sort so the
        # processing order is reproducible across runs.
        pathes = sorted(
            p for p in glob(os.path.join(args.binary_path, '*'))
            if os.path.isfile(p)
        )
    else:
        pathes = [args.binary_path]

    # Derived from the model itself instead of being read from a module-level
    # global, so the function also works when this module is imported rather
    # than run as a script. Loop-invariant, hence computed once.
    dtypes = [t.dtype for t in model.decompress.input_signature]

    for path in pathes:
        print('========================================================================')
        print('image', os.path.basename(path))

        with open(path, "rb") as f:
            packed = tfc.PackedTensors(f.read())
        tensors = packed.unpack(dtypes)
        x_hat = model.decompress(*tensors)

        # Output name keeps only the part of the file name before the first dot.
        fakepath = "./outputs/reconstruction/{}.png".format(os.path.basename(path).split('.')[0])
        string = tf.image.encode_png(x_hat)
        tf.io.write_file(fakepath, string)

        print('========================================================================\n')
if __name__ == "__main__":
    # Script entry point: expects two positional command-line arguments,
    # the saved model location and the compressed file (or directory).
    arg_parser = argparse.ArgumentParser()
    for positional in ('model_path', 'binary_path'):
        arg_parser.add_argument(positional)
    args = arg_parser.parse_args()

    model = load_model(args)
    # Dtypes of the inputs accepted by model.decompress, taken from its
    # input signature; kept as a module-level name.
    dtypes = [spec.dtype for spec in model.decompress.input_signature]
    decompress(model, args)