-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathtest.py
68 lines (49 loc) · 2.39 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from utils import *
import glob
import model
class Test(object):
    """Evaluate the restored estimator + denoiser networks on a folder of PNGs.

    Each ``*.png`` in ``input_path`` is corrupted with Gaussian noise of
    standard deviation ``sigma`` (given on the 0-255 scale), run through the
    restored networks, written to ``<output_path>/Noise<level>/<name>.png``
    and scored with ``psnr``; the mean PSNR is printed at the end.
    """

    def __init__(self, input_path, output_path, model_path, sigma, conf):
        """
        Args:
            input_path: directory globbed (non-recursively) for ``*.png`` files.
            output_path: directory under which ``Noise<level>`` is created.
            model_path: directory containing the ``EST`` and ``AWGN``
                checkpoints that are restored.
            sigma: noise standard deviation on the 0-255 intensity scale.
            conf: passed as ``config`` to ``tf.Session``.
        """
        self.input_path = input_path
        self.output_path = output_path
        self.model_path = model_path
        self.conf = conf
        # Stored rescaled by 1/255; assumes ``imread`` returns images scaled
        # to [0, 1] (scores are computed on ``img * 255``) -- TODO confirm.
        self.sigma = sigma / 255.

    def _noise_level(self):
        """Return the noise level on the 0-255 scale as an int (for naming).

        ``sigma / 255. * 255.`` is not guaranteed to round-trip exactly in
        floating point, so a plain ``int()`` / ``'%d'`` could turn 15 into
        14. Rounding to 6 decimals first removes that error while keeping
        truncation for genuinely fractional levels (15.5 -> 15).
        """
        return int(round(self.sigma * 255., 6))

    @staticmethod
    def _build_graph():
        """Create the input placeholder and wire the estimator into the denoiser.

        Returns:
            (input placeholder, denoised output tensor).
        """
        noisy_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
        # The estimator's output (sigma_hat) is fed to the denoiser together
        # with the noisy image.
        sigma_hat = model.Estimator(noisy_input, 'EST').output
        output = model.Denoiser(noisy_input, sigma_hat, 'Denoise1').output
        return noisy_input, output

    @staticmethod
    def _make_savers():
        """Build savers for the estimator and the denoiser trainable variables.

        Also calls ``count_param`` for each scope and overall (presumably
        reporting parameter counts).

        Returns:
            (estimator saver, denoiser saver).
        """
        trainable = tf.GraphKeys.TRAINABLE_VARIABLES
        saver_est = tf.train.Saver(
            var_list=tf.get_collection(trainable, scope='EST'))
        count_param('EST')
        denoiser_vars = (tf.get_collection(trainable, scope='Denoise1')
                         + tf.get_collection(trainable, scope='Noise_ENC'))
        saver_den = tf.train.Saver(var_list=denoiser_vars)
        count_param('Denoise1')
        count_param('Noise_ENC')
        count_param()
        return saver_est, saver_den

    def _restore(self, sess, saver_est, saver_den):
        """Restore the ``EST`` and ``AWGN`` checkpoints from ``model_path``."""
        ckpt_est = os.path.join(self.model_path, 'EST')
        print(ckpt_est)
        saver_est.restore(sess, ckpt_est)
        ckpt_den = os.path.join(self.model_path, 'AWGN')
        print(ckpt_den)
        saver_den.restore(sess, ckpt_den)

    def __call__(self):
        """Denoise every PNG in ``input_path``, save results, print mean PSNR."""
        img_list = np.sort(np.asarray(glob.glob('%s/*.png' % self.input_path)))
        print(img_list)
        noisy_input, output = self._build_graph()
        saver_est, saver_den = self._make_savers()
        print('Sigma: ', self.sigma * 255.)
        out_dir = '%s/Noise%d' % (self.output_path, self._noise_level())

        with tf.Session(config=self.conf) as sess:
            self._restore(sess, saver_est, saver_den)
            total = len(img_list)
            print('Process %d images' % total)
            if total == 0:
                # np.mean([]) would print 'nan' with a RuntimeWarning.
                print('No *.png images found in %s' % self.input_path)
                return

            # Created once up front; exist_ok avoids the check-then-create race.
            os.makedirs(out_dir, exist_ok=True)

            psnrs = []
            for idx, img_path in enumerate(img_list):
                img = imread(img_path)[None, :, :, :]
                # A fresh RandomState(0) per image reproduces exactly what
                # np.random.seed(0) + np.random.standard_normal drew, so each
                # image's noise is independent of how many images precede it,
                # without clobbering the global RNG state.
                rng = np.random.RandomState(0)
                noisy_img = img + rng.standard_normal(img.shape) * self.sigma
                out = sess.run(output, feed_dict={noisy_input: noisy_img})
                restored = np.clip(np.round(out[0] * 255.), 0., 255.)
                psnrs.append(psnr(img[0] * 255., restored))

                name = os.path.splitext(os.path.basename(img_path))[0]
                imageio.imsave('%s/%s.png' % (out_dir, name), np.uint8(restored))
                if idx % 5 == 0:
                    print('[%d/%d] Processing' % ((idx + 1), total))
            print('PSNR: %.4f' % np.mean(psnrs))