-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbuffer.py
77 lines (61 loc) · 2.34 KB
/
buffer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import numpy as np
from typing import Dict, Any
class SemiMDPReplayBuffer:
    """Replay buffer holding transitions ``(s, a, r, s_next, d, dt)``.

    Every field of a transition lives in its own ``Memory`` circular buffer.
    All six buffers are appended to together, so a given index refers to the
    same transition in each of them.
    """

    def __init__(self, state_dim, limit):
        """Allocate one ``Memory`` per transition field.

        Args:
            state_dim: Length of a single state vector.
            limit: Capacity passed to every underlying ``Memory``.
        """
        self.state_mem = Memory(shape=(state_dim,), limit=limit)
        # Builtin ``int`` instead of ``np.int``: that alias was deprecated in
        # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
        self.action_mem = Memory(shape=(1,), limit=limit, dtype=int)
        self.reward_mem = Memory(shape=(1,), limit=limit)
        self.next_state_mem = Memory(shape=(state_dim,), limit=limit)
        self.terminal_flag_mem = Memory(shape=(1,), limit=limit)
        self.operating_time_mem = Memory(shape=(1,), limit=limit)
        self.limit = limit
        self.size = 0

    def append(self, s, a, r, s_next, d, dt):
        """Store one transition.

        Args:
            s: State, written to ``state_mem``.
            a: Action, written to ``action_mem``.
            r: Reward, written to ``reward_mem``.
            s_next: Next state, written to ``next_state_mem``.
            d: Terminal flag, written to ``terminal_flag_mem``.
            dt: Value written to ``operating_time_mem`` (presumably the
                duration of the transition -- confirm with the caller).
        """
        self.state_mem.append(s)
        self.action_mem.append(a)
        self.reward_mem.append(r)
        self.next_state_mem.append(s_next)
        self.terminal_flag_mem.append(d)
        self.operating_time_mem.append(dt)
        # All buffers grow in lockstep, so the state buffer's length is the
        # number of stored transitions.
        self.size = len(self.state_mem)

    def sample_batch(self, batch_size: int) -> Dict[str, Any]:
        """Draw ``batch_size`` stored transitions uniformly, with replacement.

        Args:
            batch_size: Number of transitions to draw.

        Returns:
            A dict with the keys ``'state'``, ``'action'``, ``'reward'``,
            ``'next_state'``, ``'done'`` and ``'dt'``; each value is whatever
            the matching ``Memory.get_batch`` returns for the sampled indices.

        Raises:
            ValueError: Raised by NumPy when the buffer is empty and
                ``batch_size`` is non-zero.
        """
        # A fresh, OS-seeded generator per call: sampling is intentionally
        # not reproducible across calls.
        rng = np.random.default_rng()
        idxs = rng.choice(self.size, batch_size)

        # The same indices are used for every field so that the rows of the
        # returned arrays describe the same transitions.
        batch = {
            'state': self.state_mem.get_batch(idxs),
            'action': self.action_mem.get_batch(idxs),
            'reward': self.reward_mem.get_batch(idxs),
            'next_state': self.next_state_mem.get_batch(idxs),
            'done': self.terminal_flag_mem.get_batch(idxs),
            'dt': self.operating_time_mem.get_batch(idxs),
        }
        return batch

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.state_mem)
class Memory:
    """Fixed-capacity circular buffer backed by a preallocated NumPy array.

    Entries are addressed by logical index: ``0`` is the oldest stored entry
    and ``len(self) - 1`` the newest. Once ``limit`` entries are stored, each
    further ``append`` overwrites the oldest one.
    """

    def __init__(self, shape, limit=1000000, dtype=float):
        """Allocate zero-filled storage of shape ``(limit,) + shape``.

        Args:
            shape: Shape of a single entry (tuple or other sequence of ints).
            limit: Maximum number of entries kept.
            dtype: Element type of the storage. The default is the builtin
                ``float`` (float64); the former default ``np.float`` was
                only an alias for it and was removed in NumPy 1.24.
        """
        self.start = 0  # physical index of the oldest entry
        self.data_shape = shape
        self.size = 0  # number of valid entries
        self.dtype = dtype
        self.limit = limit
        # Pass dtype through so the requested element type is actually used.
        self.data = np.zeros((self.limit,) + tuple(shape), dtype=dtype)

    def append(self, data):
        """Store ``data`` as the newest entry, evicting the oldest when full.

        Args:
            data: Value assigned into one storage row; it must be
                broadcastable to the entry shape.
        """
        if self.size < self.limit:
            self.size += 1
        else:
            # Buffer is full: advance the start so the slot holding the
            # oldest entry becomes the slot for the newest one.
            self.start = (self.start + 1) % self.limit
        newest = (self.start + self.size - 1) % self.limit
        self.data[newest] = data

    def get_batch(self, idxs):
        """Return the entries at the given logical indices.

        Args:
            idxs: Integer index or array-like of integer indices, counted
                from the oldest entry. They are not checked against the
                current size, so an index >= ``len(self)`` wraps around or
                reads an unwritten, zero-filled row.

        Returns:
            The selected rows of the storage array.
        """
        # asarray lets plain lists be used as well as NumPy arrays.
        physical = (self.start + np.asarray(idxs)) % self.limit
        return self.data[physical]

    def __len__(self):
        """Return the number of entries currently stored."""
        return self.size