forked from guicho271828/latplan
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlatent_planner.py
104 lines (88 loc) · 2.76 KB
/
latent_planner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib.image as mpimg
import os
import random
import copy
import keras
from keras.layers import Input, Dense, Lambda, Dropout, BatchNormalization, GaussianNoise
from keras.models import Model, Sequential
from keras import backend as K
from keras import objectives
from keras.datasets import mnist
from keras.activations import softmax
from keras.objectives import binary_crossentropy as bce
from keras.objectives import mse
from scipy import misc
from sklearn.preprocessing import Binarizer
#img = misc.imread('/usr/share/datasets/KSCGR/hof/data1/boild-egg/hof256/0.jpg')
# Layout of the discrete latent code: N categorical variables with M
# categories each (see the (-1, N, M) reshapes further down).
M = 2
_N = 7
N = _N ** 2
# Width of the flattened vectors fed to the encoder and produced by the decoder.
latent_dim = 1764
batch_size = 1000
# Softmax temperature read by `sampling`, stored as a backend variable.
tau = K.variable(5.0, name="temperature")
def sampling(logits_y):
    """Return a Gumbel-softmax sample computed from `logits_y`.

    Gumbel noise is added to the logits, the result is viewed as N groups
    of M entries, divided by the temperature `tau` and passed through a
    softmax. The output is flattened back to N*M columns.
    """
    uniform = K.random_uniform(K.shape(logits_y), 0, 1)
    # -log(-log(U)) turns uniform noise into Gumbel noise; the 1e-20 offsets
    # keep both logarithms away from log(0).
    gumbel_noise = -K.log(-K.log(uniform + 1e-20) + 1e-20)
    grouped = K.reshape(logits_y + gumbel_noise, (-1, N, M))
    relaxed = softmax(grouped / tau)
    return K.reshape(relaxed, (-1, N * M))
def rgb2gray(rgb):
    """Collapse the last axis of `rgb` into one weighted grayscale value.

    Only the first three entries of the last axis are used (anything after
    them, e.g. an alpha channel, is ignored); they are combined with the
    weights 0.299, 0.587 and 0.114.
    """
    channel_weights = [0.299, 0.587, 0.114]
    colour_channels = rgb[..., :3]
    return np.dot(colour_channels, channel_weights)
# Encoder: flattened input -> N*M logits -> Gumbel-softmax sample.
x = Input(shape=(latent_dim,))

# Layers are appended one at a time, in the same order they are stacked.
_encoder = Sequential()
_encoder.add(GaussianNoise(0.1, input_shape=(latent_dim,)))
_encoder.add(Dense(4000, activation='relu'))
_encoder.add(BatchNormalization())
_encoder.add(Dropout(0.4))
_encoder.add(Dense(4000, activation='relu'))
_encoder.add(BatchNormalization())
_encoder.add(Dropout(0.4))
# Final layer has no activation: it emits the raw logits.
_encoder.add(Dense(M*N))

logits_y = _encoder(x)
# `sampling` turns the logits into the relaxed categorical code.
z = Lambda(sampling, output_shape=(M*N,))(logits_y)
encoder = Model(x, z)
# Decoder: N*M latent code -> reconstruction of width latent_dim.
decoder = Sequential()
decoder.add(Dropout(0.4, input_shape=(N*M, )))
decoder.add(Dense(4000, activation='relu'))
decoder.add(BatchNormalization())
decoder.add(Dropout(0.4))
decoder.add(Dense(4000, activation='relu'))
decoder.add(BatchNormalization())
decoder.add(Dropout(0.4))
# Sigmoid keeps every reconstructed value inside (0, 1).
decoder.add(Dense(latent_dim, activation='sigmoid'))

# Reconstruction obtained by decoding the sampled code `z`.
x_hat = decoder(z)
def gumbel_loss(x, x_hat):
    """Loss used to compile the autoencoder.

    Combines the binary cross-entropy between `x` and `x_hat`, scaled by
    `latent_dim`, with the KL divergence between softmax(`logits_y`) and a
    uniform distribution over the M categories. `logits_y` is read from
    module scope rather than passed in.

    NOTE(review): the KL term is subtracted, yet the returned value is what
    the optimizer minimises; a negative ELBO would normally add it. Confirm
    the sign is intended before retraining with this loss.
    """
    probs = softmax(K.reshape(logits_y, (-1, N, M)))
    # 1e-20 guards against log(0) for categories with zero probability.
    log_probs = K.log(probs + 1e-20)
    kl_to_uniform = K.sum(probs * (log_probs - K.log(1.0/M)), axis=(1, 2))
    reconstruction = latent_dim * bce(x, x_hat)
    # Alternative reconstruction term: latent_dim * mse(x, x_hat)
    return reconstruction - kl_to_uniform
# End-to-end autoencoder: input -> sampled latent code -> reconstruction.
vae = Model(x, x_hat)
vae.compile(loss=gumbel_loss, optimizer='adam')

# Restore the stored weights of both halves of the network.
decoder.load_weights('model_new/decoder.h5')
encoder.load_weights('model_new/encoder.h5')
# Encode two eight-puzzle images, decode the discretised codes again and
# display the reconstruction of the second image.
test = np.array([
    misc.imread("../eight-puzzle_mnist/102345678.jpg"),
    misc.imread("../eight-puzzle_mnist/781023456.jpg"),
])
# Scale pixel values from 0..255 down to 0..1.
test = test.astype('float32') / 255.
# One flattened row per image, as expected by the encoder input.
test = test.reshape((len(test), -1))

# Binarize with an explicit assignment: the previous code discarded the
# return value of transform() and depended on copy=False mutating views of
# `test` in place. The keyword form of `threshold` is accepted by every
# scikit-learn release, unlike the positional form.
b = Binarizer(threshold=0.3)
test = b.transform(test)

code = encoder.predict(test)
# The encoder emits softmax outputs in [0, 1]. A bare astype(int) truncates
# toward zero and turns every value below 1.0 into 0, so round to the
# nearest integer first to obtain the discrete code.
code = np.round(code).astype(int)
code2 = decoder.predict(code)

# Parenthesised form prints the same thing on Python 2 and Python 3.
print(len(code[0]))

plt.figure(figsize=(30, 5))
# Reconstructions are flat vectors; restore the 42x42 image for display.
plt.imshow(code2[1].reshape(42, 42), cmap='gray')
plt.show()