-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdnc_downpour_sgd.py
162 lines (133 loc) · 4.5 KB
/
dnc_downpour_sgd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 07:18:57 2017
@author: ryuhei
"""
import multiprocessing as mp
from queue import Empty, Full
import random
import matplotlib.pyplot as plt
import numpy as np
import chainer
from chainer import cuda
import chainer.functions as F
from copy_dataset import generate_copy_data
from dnc import Controller
def worker(batch_size, seq_len, dim_x, receive_queue, send_queue):
    """Downpour-SGD worker: train a local replica and ship gradients back.

    Loops forever: pulls the newest master model from ``receive_queue``
    (when available), computes gradients on a freshly generated copy-task
    batch, and pushes the gradient-carrying model into ``send_queue``.

    Args:
        batch_size (int): Number of sequences per training batch.
        seq_len (int or (int, int)): Fixed sequence length, or an inclusive
            ``(min, max)`` range each batch's length is drawn from.
        dim_x (int): Dimensionality of each input vector.
        receive_queue (mp.Queue): Master -> worker channel carrying models.
        send_queue (mp.Queue): Worker -> master channel carrying gradients.
    """
    # Seed NumPy per process so each worker generates distinct batches.
    pid = mp.current_process()._identity[0]
    np.random.seed(pid)

    # Block until the initial master model arrives.
    model = receive_queue.get()

    # Normalize a scalar length to a degenerate (min, max) range.
    if np.isscalar(seq_len):
        seq_len = (seq_len, seq_len)
    min_seq_len = min(seq_len)
    max_seq_len = max(seq_len)

    while True:
        # Pick up a newer master model if one is waiting; otherwise keep
        # training on the stale copy (asynchronous/downpour behaviour).
        try:
            model = receive_queue.get(block=False)
        except Empty:
            pass

        random_seq_len = random.randint(min_seq_len, max_seq_len)
        x, t = generate_copy_data(batch_size, random_seq_len, dim_x)
        x = x.transpose((1, 0, 2))  # -> (seq, batch, dim)
        t = t.transpose((1, 0, 2))

        # Feed the input sequence into the controller.
        model.reset_state(batch_size)
        for x_t in x:
            model(x_t)

        # Read the answer out with all-zero inputs. The dummy input is
        # loop-invariant, so build it once up front (the original rebuilt
        # it every iteration AND relied on the loop variable x_t leaking
        # out of the input loop above).
        dummy_input = np.zeros_like(x[0])
        y = [model(dummy_input) for _ in range(len(t))]
        y = F.stack(y)

        loss = F.sigmoid_cross_entropy(y, t)
        model.cleargrads()
        loss.backward()

        # Ship the gradients; drop them if the master is not keeping up.
        try:
            send_queue.put(model, block=False)
        except Full:
            pass
if __name__ == '__main__':
    # ---- Hyperparameters -------------------------------------------------
    train_batch_size = 1
    train_seq_len = (5, 20)       # (min, max) training sequence length
    test_batch_size = 40
    test_seq_len = 20
    dim_x = 9
    dim_y = dim_x                 # copy task: output mirrors input
    dim_h = 100
    num_memory_slots = 128
    dim_memory_vector = 20
    num_read_heads = 1
    num_processes = 5
    receive_queue_size = 5
    num_updates = 1000000
    learning_rate = 0.0001

    # ---- Master model and optimizer --------------------------------------
    model_master = Controller(dim_x, dim_y, dim_h, num_memory_slots,
                              dim_memory_vector, num_read_heads)
    optimizer = chainer.optimizers.RMSprop(learning_rate)
    optimizer.setup(model_master)
    optimizer.zero_grads()

    # Call the forward once in order to resolve uninitialized variables
    # (target is unused here, hence `_`).
    x, _ = generate_copy_data(1, 1, dim_x)
    x = x.transpose((1, 0, 2))
    model_master.reset_state(1)
    model_master(x[0])

    # ---- Create worker processes -----------------------------------------
    processes = []
    send_queues = []  # Interface to send the master model to workers
    receive_queue = mp.Queue(maxsize=receive_queue_size)  # To receive models
    for p in range(num_processes):
        send_queue = mp.Queue(maxsize=1)
        process = mp.Process(target=worker,
                             args=(train_batch_size, train_seq_len, dim_x,
                                   send_queue, receive_queue))
        process.start()
        processes.append(process)
        send_queue.put(model_master)  # hand each worker the initial model
        send_queues.append(send_queue)

    count_updates = 0
    evaluate_at = 0
    acc_log = []
    try:
        while count_updates < num_updates:
            # Apply gradients from any worker that has finished a batch.
            try:
                model = receive_queue.get(block=False)
                model_master.cleargrads()
                model_master.addgrads(model)
                optimizer.update()
                count_updates += 1
                # Broadcast the updated master to every worker whose inbox
                # has room. BUG FIX: the Full handler previously sat outside
                # this loop, so the first full queue aborted the whole
                # broadcast and starved the remaining workers of updates;
                # handle Full per queue instead.
                for p in np.random.permutation(num_processes):
                    try:
                        send_queues[p].put(model_master, block=False)
                    except Full:
                        pass
            except Empty:
                pass

            # ---- Evaluation (once per 100 updates) -----------------------
            # evaluate_at guards against re-evaluating while count_updates
            # stays on the same multiple of 100 across loop iterations.
            if count_updates % 100 == 0 and evaluate_at != count_updates:
                evaluate_at = count_updates
                x, t = generate_copy_data(test_batch_size, test_seq_len, dim_x)
                x = x.transpose((1, 0, 2))
                t = t.transpose((1, 0, 2))
                model_master.reset_state(test_batch_size)
                for x_t in x:
                    model_master(x_t)
                y = []
                for t_t in t:
                    dummy_input = np.zeros_like(x_t)
                    y_t = model_master(dummy_input)
                    y.append(y_t)
                y = F.stack(y)
                loss = F.sigmoid_cross_entropy(y, t)
                loss_data = cuda.to_cpu(loss.data)
                acc = F.binary_accuracy(y, t)
                acc_data = cuda.to_cpu(acc.data)
                acc_log.append(acc_data)
                print('{}: {:0.5},\t{:1.5}'.format(
                    count_updates, float(acc_data), float(loss_data)))
    except KeyboardInterrupt:
        print('Ctrl+c')

    # Tear down workers and plot the accuracy curve.
    for process in processes:
        process.terminate()
    plt.plot(acc_log)
    plt.grid()
    plt.ylim([0, 1])
    plt.show()