diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..80dd262
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+log/
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..d5b8fc6
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,6 @@
+[submodule "commons"]
+	path = commons
+	url = git@github.com:hiwonjoon/tf-boilerplate.git
+[submodule "libs/flann"]
+	path = libs/flann
+	url = git@github.com:hiwonjoon/flann
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..603e9e3
--- /dev/null
+++ b/README.md
@@ -0,0 +1,4 @@
+# Neural Episodic Control
+
+A TensorFlow implementation of Neural Episodic Control (Pritzel et al., 2017).
+
diff --git a/commons b/commons
new file mode 160000
index 0000000..6125939
--- /dev/null
+++ b/commons
@@ -0,0 +1 @@
+Subproject commit 612593975705c9a918fd28cfe0698a776d18bd97
diff --git a/fast_dictionary.py b/fast_dictionary.py
new file mode 100644
index 0000000..2a51072
--- /dev/null
+++ b/fast_dictionary.py
@@ -0,0 +1,122 @@
+import collections
+import os
+import pickle
+import numpy as np
+from pyflann import FLANN
+# ngtpy is buggy (its incremental remove/add is fragile), so FLANN is used instead.
+#import ngtpy
+
+class FastDictionary(object):
+    def __init__(self,maxlen):
+        self.flann = FLANN()
+
+        self.counter = 0
+
+        self.contents_lookup = {} # {oid: (embedding, q-value)}
+        self.p_queue = collections.deque() # eviction queue: oids in insertion/update order
+        self.maxlen = maxlen
+
+    def save(self,dir,fname,it=None):
+        fname = f'{fname}' if it is None else f'{fname}-{it}'
+
+        with open(os.path.join(dir,fname),'wb') as f:
+            pickle.dump((self.contents_lookup,self.p_queue,self.maxlen),f)
+
+    def restore(self,fname):
+        with open(fname,'rb') as f:
+            _contents_lookup, _p_queue, maxlen = pickle.load(f)
+
+        assert self.maxlen == maxlen, (self.maxlen,maxlen)
+
+        new_oid_lookup = {}
+        E = []
+        for oid,(e,q) in _contents_lookup.items():
+            E.append(e)
+
+            new_oid, self.counter = self.counter, self.counter+1
+
+            new_oid_lookup[oid] = new_oid
+            self.contents_lookup[new_oid] = (e,q)
+
+        # Rebuild the KD-tree (FLANN index)
+        self.flann.build_index(np.array(E))
+
+        # Rebuild the eviction queue
+        while len(_p_queue) > 0:
+            oid = _p_queue.popleft()
+
+            if oid not in new_oid_lookup:
+                continue
+            self.p_queue.append(new_oid_lookup[oid])
+
+    def add(self,E,Contents):
+        assert not np.isnan(E).any(), ('NaN Detected in Add',np.argwhere(np.isnan(E)))
+        assert len(E) == len(Contents)
+
+        if self.counter == 0:
+            self.flann.build_index(E)
+        else:
+            self.flann.add_points(E)
+        Oid, self.counter = np.arange(self.counter,self.counter+len(E)), self.counter + len(E)
+
+        for oid,content in zip(Oid,Contents):
+            self.contents_lookup[oid] = content
+            self.p_queue.append(oid)
+
+            if len(self.contents_lookup) > self.maxlen:
+                while self.p_queue[0] not in self.contents_lookup:
+                    self.p_queue.popleft() # oid was invalidated by an earlier update; just drop it.
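+                # The front of the deque now holds the oldest still-valid entry;
+                # evict it from both the FLANN index and the lookup table (oldest-first eviction).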
+                old_oid = self.p_queue.popleft()
+
+                self.flann.remove_point(old_oid)
+                del self.contents_lookup[old_oid]
+
+    def query_knn(self,E,K=100):
+        assert not np.isnan(E).any(), ('NaN Detected in Querying',np.argwhere(np.isnan(E)))
+
+        flatten = False
+        if E.ndim == 1:
+            E = E[None]
+            flatten = True
+
+        Oids, _ = self.flann.nn_index(E,num_neighbors=K)
+        NN_E = np.zeros((len(E),K,E.shape[1]),np.float32)
+        NN_Q = np.zeros((len(E),K),np.float32)
+
+        for b,oids in enumerate(Oids):
+            for k,oid in enumerate(oids):
+                e,q = self.contents_lookup[oid]
+
+                NN_E[b,k] = e
+                NN_Q[b,k] = q
+
+        if flatten:
+            return Oids, NN_E[0], NN_Q[0]
+        else:
+            return Oids, NN_E, NN_Q
+
+    def update(self,Oid,E,Contents):
+        """
+        Basically, this is the same as a remove followed by an add.
+        It only manages the eviction queue more carefully, since deleting an item
+        from the middle of the queue is not trivial.
+        """
+        assert not np.isnan(E).any(), ('NaN Detected in Updating',np.argwhere(np.isnan(E)))
+        assert len(np.unique(Oid)) == len(Oid)
+
+        # add the new embeddings
+        self.flann.add_points(E)
+        NewOid, self.counter = np.arange(self.counter,self.counter+len(E)), self.counter + len(E)
+
+        for oid,new_oid,content in zip(Oid,NewOid,Contents):
+            self.contents_lookup[new_oid] = content
+            self.p_queue.append(new_oid)
+
+            # delete from the kd-tree
+            self.flann.remove_point(oid)
+            # delete from contents_lookup
+            del self.contents_lookup[oid]
+            # The old oid cannot be removed from p_queue here; it is skipped as an
+            # invalidated entry during eviction in add().
+
+if __name__ == "__main__":
+    pass
diff --git a/libs/atari_wrappers.py b/libs/atari_wrappers.py
new file mode 100644
index 0000000..f715c85
--- /dev/null
+++ b/libs/atari_wrappers.py
@@ -0,0 +1,249 @@
+# Code from openai/baselines
+# https://raw.githubusercontent.com/openai/baselines/master/baselines/common/atari_wrappers.py
+
+import numpy as np
+import os
+os.environ.setdefault('PATH', '')
+from collections import deque
+import gym
+from gym import spaces
+import cv2
+cv2.ocl.setUseOpenCL(False)
+
+class NoopResetEnv(gym.Wrapper):
+    def __init__(self, env, noop_max=30):
+        """Sample initial states by taking random number of no-ops on reset.
+        No-op is assumed to be action 0.
+        """
+        gym.Wrapper.__init__(self, env)
+        self.noop_max = noop_max
+        self.override_num_noops = None
+        self.noop_action = 0
+        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
+
+    def reset(self, **kwargs):
+        """ Do no-op action for a number of steps in [1, noop_max]."""
+        self.env.reset(**kwargs)
+        if self.override_num_noops is not None:
+            noops = self.override_num_noops
+        else:
+            noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) #pylint: disable=E1101
+        assert noops > 0
+        obs = None
+        for _ in range(noops):
+            obs, _, done, _ = self.env.step(self.noop_action)
+            if done:
+                obs = self.env.reset(**kwargs)
+        return obs
+
+    def step(self, ac):
+        return self.env.step(ac)
+
+class FireResetEnv(gym.Wrapper):
+    def __init__(self, env):
+        """Take action on reset for environments that are fixed until firing."""
+        gym.Wrapper.__init__(self, env)
+        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
+        assert len(env.unwrapped.get_action_meanings()) >= 3
+
+    def reset(self, **kwargs):
+        self.env.reset(**kwargs)
+        obs, _, done, _ = self.env.step(1)
+        if done:
+            self.env.reset(**kwargs)
+        obs, _, done, _ = self.env.step(2)
+        if done:
+            self.env.reset(**kwargs)
+        return obs
+
+    def step(self, ac):
+        return self.env.step(ac)
+
+class EpisodicLifeEnv(gym.Wrapper):
+    def __init__(self, env):
+        """Make end-of-life == end-of-episode, but only reset on true game over.
+ Done by DeepMind for the DQN and co. since it helps value estimation. + """ + gym.Wrapper.__init__(self, env) + self.lives = 0 + self.was_real_done = True + + def step(self, action): + obs, reward, done, info = self.env.step(action) + self.was_real_done = done + # check current lives, make loss of life terminal, + # then update lives to handle bonus lives + lives = self.env.unwrapped.ale.lives() + if lives < self.lives and lives > 0: + # for Qbert sometimes we stay in lives == 0 condition for a few frames + # so it's important to keep lives > 0, so that we only reset once + # the environment advertises done. + done = True + self.lives = lives + return obs, reward, done, info + + def reset(self, **kwargs): + """Reset only when lives are exhausted. + This way all states are still reachable even though lives are episodic, + and the learner need not know about any of this behind-the-scenes. + """ + if self.was_real_done: + obs = self.env.reset(**kwargs) + else: + # no-op step to advance from terminal/lost life state + obs, _, _, _ = self.env.step(0) + self.lives = self.env.unwrapped.ale.lives() + return obs + +class MaxAndSkipEnv(gym.Wrapper): + def __init__(self, env, skip=4): + """Return only every `skip`-th frame""" + gym.Wrapper.__init__(self, env) + # most recent raw observations (for max pooling across time steps) + self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8) + self._skip = skip + + def step(self, action): + """Repeat action, sum reward, and max over last observations.""" + total_reward = 0.0 + done = None + for i in range(self._skip): + obs, reward, done, info = self.env.step(action) + if i == self._skip - 2: self._obs_buffer[0] = obs + if i == self._skip - 1: self._obs_buffer[1] = obs + total_reward += reward + if done: + break + # Note that the observation on the done=True frame + # doesn't matter + max_frame = self._obs_buffer.max(axis=0) + + return max_frame, total_reward, done, info + + def reset(self, **kwargs): + return self.env.reset(**kwargs) + +class ClipRewardEnv(gym.RewardWrapper): + def __init__(self, env): + gym.RewardWrapper.__init__(self, env) + + def reward(self, reward): + """Bin reward to {+1, 0, -1} by its sign.""" + return np.sign(reward) + +class WarpFrame(gym.ObservationWrapper): + def __init__(self, env, width=84, height=84, grayscale=True): + """Warp frames to 84x84 as done in the Nature paper and later work.""" + gym.ObservationWrapper.__init__(self, env) + self.width = width + self.height = height + self.grayscale = grayscale + if self.grayscale: + self.observation_space = spaces.Box(low=0, high=255, + shape=(self.height, self.width, 1), dtype=np.uint8) + else: + self.observation_space = spaces.Box(low=0, high=255, + shape=(self.height, self.width, 3), dtype=np.uint8) + + def observation(self, frame): + if self.grayscale: + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) + frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA) + if self.grayscale: + frame = np.expand_dims(frame, -1) + return frame + +class FrameStack(gym.Wrapper): + def __init__(self, env, k): + """Stack k last frames. + + Returns lazy array, which is much more memory efficient. 
+ + See Also + -------- + baselines.common.atari_wrappers.LazyFrames + """ + gym.Wrapper.__init__(self, env) + self.k = k + self.frames = deque([], maxlen=k) + shp = env.observation_space.shape + self.observation_space = spaces.Box(low=0, high=255, shape=(shp[:-1] + (shp[-1] * k,)), dtype=env.observation_space.dtype) + + def reset(self): + ob = self.env.reset() + for _ in range(self.k): + self.frames.append(ob) + return self._get_ob() + + def step(self, action): + ob, reward, done, info = self.env.step(action) + self.frames.append(ob) + return self._get_ob(), reward, done, info + + def _get_ob(self): + assert len(self.frames) == self.k + return LazyFrames(list(self.frames)) + +class ScaledFloatFrame(gym.ObservationWrapper): + def __init__(self, env): + gym.ObservationWrapper.__init__(self, env) + self.observation_space = gym.spaces.Box(low=0, high=1, shape=env.observation_space.shape, dtype=np.float32) + + def observation(self, observation): + # careful! This undoes the memory optimization, use + # with smaller replay buffers only. + return np.array(observation).astype(np.float32) / 255.0 + +class LazyFrames(object): + def __init__(self, frames): + """This object ensures that common frames between the observations are only stored once. + It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay + buffers. + + This object should only be converted to numpy array before being passed to the model. + + You'd not believe how complex the previous solution was.""" + self._frames = frames + self._out = None + + def _force(self): + if self._out is None: + self._out = np.concatenate(self._frames, axis=-1) + self._frames = None + return self._out + + def __array__(self, dtype=None): + out = self._force() + if dtype is not None: + out = out.astype(dtype) + return out + + def __len__(self): + return len(self._force()) + + def __getitem__(self, i): + return self._force()[i] + +def make_atari(env_id): + env = gym.make(env_id) + assert 'NoFrameskip' in env.spec.id + env = NoopResetEnv(env, noop_max=30) + env = MaxAndSkipEnv(env, skip=4) + return env + +def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False): + """Configure environment for DeepMind-style Atari. 
+ """ + if episode_life: + env = EpisodicLifeEnv(env) + if 'FIRE' in env.unwrapped.get_action_meanings(): + env = FireResetEnv(env) + env = WarpFrame(env) + if scale: + env = ScaledFloatFrame(env) + if clip_rewards: + env = ClipRewardEnv(env) + if frame_stack: + env = FrameStack(env, 4) + return env + diff --git a/libs/flann b/libs/flann new file mode 160000 index 0000000..f91027d --- /dev/null +++ b/libs/flann @@ -0,0 +1 @@ +Subproject commit f91027d3060f073063ffc2edfbb876aa87617a3b diff --git a/q_learning.py b/q_learning.py new file mode 100644 index 0000000..62fe514 --- /dev/null +++ b/q_learning.py @@ -0,0 +1,351 @@ +import os +import itertools +import argparse +from pathlib import Path +import numpy as np +import tensorflow as tf +import gym +from tqdm import tqdm + +from commons.ops import * +from fast_dictionary import FastDictionary +from replay_buffer import ReplayBuffer +from libs.atari_wrappers import make_atari, wrap_deepmind + +FRAME_STACK = 4 + +def _build(net,x): + for block in net: x = block(x) + return x + +class NEC(object): + def __init__(self, + num_ac, + memory_max_len, + K, + embed_len, + delta=1e-3, + q_lr=1e-2, + ): + self.delta = delta + self.q_lr = q_lr + self.K = K + self.embed_len = embed_len + + self.Qa = [FastDictionary(memory_max_len) for _ in range(num_ac)] + + # Experience reaclled from a replay buffer through FastDictionary + self.s = tf.placeholder(tf.float32,[None,84,84,FRAME_STACK]) #[B,seq_len,state_dims] + + self.nn_es = tf.placeholder(tf.float32,[None,None,embed_len]) #[B,K,embed_len] + self.nn_qs = tf.placeholder(tf.float32,[None,None]) + self.target_q = tf.placeholder(tf.float32,[None]) + + with tf.variable_scope('weights') as param_scope: + self.param_scope = param_scope + + self.net = [ + Conv2d('c1',4,32,k_h=8,k_w=8,d_h=4,d_w=4,padding='VALID',data_format='NHWC'), + lambda x: tf.nn.relu(x), + Conv2d('c2',32,64,k_h=4,k_w=4,d_h=2,d_w=2,padding='VALID',data_format='NHWC'), + lambda x: tf.nn.relu(x), + Conv2d('c3',64,64,k_h=3,k_w=3,d_h=1,d_w=1,padding='VALID',data_format='NHWC'), + lambda x: tf.nn.relu(x), + Linear('fc1',3136,512), + lambda x: tf.nn.relu(x), + Linear('fc2',512,embed_len) + ] + + self.embed = _build(self.net,self.s) + + dists = tf.reduce_sum((self.nn_es - self.embed[:,None])**2,axis=2) + kernel = 1 / (dists + self.delta) + + #kernel = tf.Print(kernel,[tf.shape(self.nn_es),tf.shape(dists),tf.shape(kernel)]) + + q = tf.reduce_sum(kernel * self.nn_qs, axis=1) / tf.reduce_sum(kernel, axis=1, keep_dims=True) + self.loss = tf.reduce_mean((self.target_q - q)**2) + tf.summary.scalar('loss',self.loss,collections=['summaries']) + + + # Optimize op + #self.optim = tf.train.AdamOptimizer(1e-4) + self.optim = tf.train.RMSPropOptimizer(1e-4) + self.update_op = self.optim.minimize(self.loss,var_list=self.parameters(train=True)) + + self.nn_es_gradient, self.nn_qs_gradient = tf.gradients(self.loss, [self.nn_es,self.nn_qs]) + self.new_nn_es = self.nn_es - self.q_lr * self.nn_es_gradient + self.new_nn_qs = self.nn_qs - self.q_lr * self.nn_qs_gradient + + self.summaries = tf.summary.merge_all(key='summaries') + self.saver = tf.train.Saver(var_list=self.parameters(train=True),max_to_keep=0) + + def parameters(self,train=False): + if train: + return tf.trainable_variables(self.param_scope.name) + else: + return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,self.param_scope.name) + + def save(self,dir,it=None): + sess = tf.get_default_session() + self.saver.save(sess,dir+'/model.ckpt',global_step=it,write_meta_graph=False) + + for i,Q in 
+        for i,Q in enumerate(self.Qa):
+            Q.save(dir,f'Q-{i}.pkl',it)
+
+    def restore(self,model_file):
+        model_dir,file_name = os.path.split(model_file)
+        it = None if '-' not in file_name else int(file_name.split('-')[-1])
+
+        sess = tf.get_default_session()
+
+        self.saver.restore(sess,model_file)
+
+        for i,Q in enumerate(self.Qa):
+            memory_name = f'Q-{i}.pkl' if it is None else f'Q-{i}.pkl-{it}'
+
+            Q.restore(os.path.join(model_dir,memory_name))
+
+    # NEC Impl
+    def _embed(self,b_s,max_batch_size=1024):
+        sess = tf.get_default_session()
+
+        b_e = []
+        for i in range(0,len(b_s),max_batch_size):
+            b_e.append(
+                sess.run(self.embed,feed_dict={self.s:b_s[i:i+max_batch_size]}))
+        return np.concatenate(b_e,axis=0)
+
+    def _read_table(self,e,Q,K):
+        oids, nn_es, nn_qs = Q.query_knn(e,K=K)
+
+        dists = np.linalg.norm(nn_es - e,axis=1)**2
+        kernel = 1 / (dists + self.delta)
+
+        q = np.sum(kernel * nn_qs) / np.sum(kernel)
+
+        return oids, q
+
+    def policy(self,s):
+        e = self._embed(s[None])[0]
+
+        qs = [self._read_table(e,Q,K=self.K)[1] for Q in self.Qa]
+
+        ac = np.argmax(qs)
+        return ac, (e,qs[ac])
+
+    def append(self,e,a,q):
+        sess = tf.get_default_session()
+
+        self.Qa[a].add(e[None],[(e,q)])
+
+    def update(self,b_s,b_a,b_q):
+        sess = tf.get_default_session()
+
+        b_e = self._embed(b_s)
+
+        b_nn_es = np.zeros((len(b_s),self.K,self.embed_len),np.float32)
+        b_nn_qs = np.zeros((len(b_s),self.K),np.float32)
+
+        for i,Q in enumerate(self.Qa):
+            idxes = np.where(b_a==i)
+
+            Oids, nn_Es, nn_Qs = Q.query_knn(b_e[idxes],K=self.K)
+
+            b_nn_es[idxes] = nn_Es
+            b_nn_qs[idxes] = nn_Qs
+
+            # Update the table (embeddings & q values) with one gradient step
+            new_Es, new_Qs = \
+                sess.run([self.new_nn_es,self.new_nn_qs],feed_dict={
+                    self.s:b_s[idxes],
+                    self.nn_es:nn_Es,
+                    self.nn_qs:nn_Qs,
+                    self.target_q:b_q[idxes],
+                })
+
+            oids = np.reshape(Oids,[-1])
+            new_es = np.reshape(new_Es,[-1,self.embed_len])
+            new_qs = np.reshape(new_Qs,[-1])
+
+            _, unique_idxes = np.unique(oids,return_index=True)
+            Q.update(oids[unique_idxes],
+                     new_es[unique_idxes],
+                     list(zip(new_es[unique_idxes],new_qs[unique_idxes])))
+
+        # Update the embedding network
+        loss, _, summary_str = sess.run([self.loss,self.update_op,self.summaries],feed_dict={
+            self.s:b_s,
+            self.nn_es:b_nn_es,
+            self.nn_qs:b_nn_qs,
+            self.target_q:b_q,
+        })
+
+        return loss, summary_str
+
+
+def train(
+    args,
+    log_dir,
+    seed,
+    env_id,
+    replay_buffer_len,
+    memory_len,
+    p, # number of nearest neighbours; reported number is 50
+    embed_size, # embedding vector length; reported number is ?
+    gamma, # discount factor; reported number is 0.99
+    N, # N-step bootstrapping horizon; reported number is 100
+    update_period, # the reported number is 16
+    batch_size, # the reported number is 32
+    init_eps,
+    **kwargs
+):
+    # other hyperparameters
+    _gw = np.array([gamma**i for i in range(N)])
+    epsilon = 1.0
+    min_epsilon = 0.01
+    epsilon_decay = 0.99 # exponential decay factor
+
+    # expr setting
+    Path(log_dir).mkdir(parents=True,exist_ok='temp' in log_dir)
+
+    with open(os.path.join(log_dir,'args.txt'),'w') as f:
+        f.write( str(args) )
+
+    np.random.seed(seed)
+    tf.random.set_random_seed(seed)
+
+    # Env
+    env = wrap_deepmind(make_atari(env_id),
+                        episode_life=False,
+                        clip_rewards=False,
+                        frame_stack=True,
+                        scale=True)
+    num_ac = env.action_space.n
+
+    # ReplayBuffer
+    replay_buffer = ReplayBuffer(replay_buffer_len)
+
+    # Neural Episodic Controller
+    nec = NEC(num_ac,memory_len,p,embed_size)
+
+    sess = tf.InteractiveSession()
+    sess.run(tf.global_variables_initializer())
+
+    summary_writer = tf.summary.FileWriter(os.path.join(log_dir,'tensorboard'))
+
+    ####### Setup Done
+
+    num_frames = 0 # number of frames observed
+
+    # Fill up the memory and replay buffer with a random policy
+    for _ in range(init_eps):
+        ob = env.reset()
+
+        obs,acs,rewards = [ob],[],[]
+        for _ in itertools.count():
+            ac = np.random.randint(num_ac)
+
+            ob,r,done,_ = env.step(ac)
+
+            obs.append(ob)
+            acs.append(ac)
+            rewards.append(r)
+
+            num_frames += 1
+
+            if done:
+                break
+
+        Rs = [np.sum(_gw[:len(rewards[i:i+N])]*rewards[i:i+N]) for i in range(len(rewards))]
+
+        obs = np.array(obs)
+        es = nec._embed(obs)
+
+        for ob,e,a,R in zip(obs,es,acs,Rs):
+            nec.append(e, a, R)
+
+            replay_buffer.append(ob,a,R)
+
+    # Training!
+    try:
+        for ep in itertools.count():
+            ob = env.reset()
+
+            obs,acs,rewards,es,Vs = [ob],[],[],[],[]
+            for t in itertools.count():
+                # Epsilon Greedy Policy
+                ac, (e,V) = nec.policy(ob)
+                if np.random.random() < epsilon:
+                    ac = np.random.randint(num_ac)
+
+                ob,r,done,_ = env.step(ac)
+
+                obs.append(ob)
+                acs.append(ac)
+                rewards.append(r)
+                es.append(e)
+                Vs.append(V)
+
+                num_frames += 1
+
+                # Train on a random minibatch from the replay buffer
+                if num_frames % update_period == 0:
+                    b_s,b_a,b_R = replay_buffer.sample(batch_size)
+                    loss, summary = nec.update(b_s,b_a,b_R)
+                    summary_writer.add_summary(summary, num_frames)
+                    print(f'[{num_frames}] loss: {loss}')
+
+                if t >= N:
+                    # N-Step Bootstrapping
+                    # TODO: implement the efficient version
+                    R = np.sum(_gw * rewards[t-N:t]) + gamma**N*Vs[t] #R_{t-N}
+
+                    # append to memory
+                    nec.append(es[t-N], acs[t-N], R)
+
+                    # append to replay buffer
+                    replay_buffer.append(obs[t-N], acs[t-N], R)
+
+                if done:
+                    break
+
+            print(f'Episode {ep} -- Ep Len: {len(obs)} Acc Reward: {np.sum(rewards)}')
+
+            # Remaining items that cannot be bootstrapped: the partial returns near the end of the episode.
+            # Append them to the memory & replay buffer.
+            for t in range(max(0,len(rewards)-N),len(rewards)):
+                R = np.sum([gamma**(i-t)*rewards[i] for i in range(t,len(rewards))])
+                nec.append(es[t], acs[t], R)
+                replay_buffer.append(obs[t], acs[t], R)
+
+            # epsilon decay
+            epsilon = max(min_epsilon, epsilon * epsilon_decay)
+
+    except KeyboardInterrupt:
+        nec.save(log_dir)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description=None)
+    # expr setting
+    parser.add_argument('--log_dir',required=True)
+    parser.add_argument('--seed',type=int,default=0)
+    parser.add_argument('--mode',default='train',choices=['train'])
+    # Env
+    parser.add_argument('--env_id', default='PongNoFrameskip-v4', help='Gym environment id; e.g. PongNoFrameskip-v4')
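+    # RL / NEC hyperparameters (defaults follow the values noted in train() where reported)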
+    parser.add_argument('--gamma',type=float,default=0.99)
+    # network
+    parser.add_argument('--embed_size',type=int,default=64)
+    parser.add_argument('--memory_len',type=int,default=int(5*1e5))
+    parser.add_argument('--replay_buffer_len',type=int,default=int(1e5))
+    parser.add_argument('--p',type=int,default=50)
+    # Training
+    parser.add_argument('--init_eps',type=int,default=10,help='number of episodes run with a random policy to initialize the memory and replay buffer')
+    parser.add_argument('--N',type=int,default=100,help='horizon for N-step bootstrapping')
+    parser.add_argument('--update_period',type=int,default=16)
+    parser.add_argument('--batch_size',type=int,default=32)
+
+    args = parser.parse_args()
+    if args.mode == 'train':
+        train(args=args,**vars(args))
+
diff --git a/replay_buffer.py b/replay_buffer.py
new file mode 100644
index 0000000..cb6805d
--- /dev/null
+++ b/replay_buffer.py
@@ -0,0 +1,24 @@
+import os
+import numpy as np
+
+class ReplayBuffer(object):
+    def __init__(self,maxlen):
+        self.storage = []
+
+        self.counter = 0
+        self.maxlen = maxlen
+
+    def append(self,s,a,R):
+        ticket, self.counter = self.counter, self.counter + 1
+        if ticket < self.maxlen:
+            self.storage.append((s,a,R))
+        else:
+            self.storage[ticket%self.maxlen] = (s,a,R)
+
+    def sample(self,size):
+        # Sample uniformly over the filled portion of the buffer.
+        idxes = np.random.randint(min(self.counter,self.maxlen),size=size)
+
+        b_s,b_a,b_R = zip(*[self.storage[i] for i in idxes])
+        b_s,b_a,b_R = np.array(b_s), np.array(b_a), np.array(b_R)
+
+        return b_s,b_a,b_R
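+
+if __name__ == "__main__":
+    # Minimal usage sketch (not part of the training pipeline): exercises append/sample
+    # with dummy transitions to sanity-check the ring-buffer behaviour and output shapes.
+    buf = ReplayBuffer(maxlen=8)
+    for t in range(20):
+        buf.append(np.zeros((84,84,4),np.float32), t % 4, float(t))
+    b_s,b_a,b_R = buf.sample(5)
+    print(b_s.shape, b_a.shape, b_R.shape)  # -> (5, 84, 84, 4) (5,) (5,)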