-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathreplay_memory.py
87 lines (68 loc) · 3.07 KB
/
replay_memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#!/usr/bin/env python3
import random
import numpy as np
class ReplayMemory(object):
    """Fixed-capacity ring buffer of transitions with preallocated batches.

    Each transition is ``(s1, a, s2, r, terminal)``. A state is a sequence
    whose element ``0`` is an image of shape ``img_shape`` and, when
    ``misc_len > 0``, whose element ``1`` is a vector of ``misc_len`` values.

    Once ``capacity`` transitions are stored, new ones overwrite the oldest.

    Args:
        img_shape: Shape of a single state image.
        misc_len: Length of the auxiliary state vector; ``0`` disables it.
        capacity: Maximum number of transitions kept.
        batch_size: Number of transitions returned by ``get_sample``.
        dtype: NumPy dtype used for images, misc vectors and rewards.
    """

    def __init__(self, img_shape, misc_len=0, capacity=10000, batch_size=32, dtype=np.float32):
        img_shape = list(img_shape)
        storage_img_shape = [capacity] + img_shape
        batch_img_shape = [batch_size] + img_shape

        # Transition storage, indexed by slot in the ring buffer.
        self._s1_img = np.zeros(storage_img_shape, dtype=dtype)
        self._s2_img = np.zeros(storage_img_shape, dtype=dtype)
        self._a = np.zeros(capacity, dtype=np.int32)
        self._r = np.zeros(capacity, dtype=dtype)
        self._terminal = np.zeros(capacity, dtype=np.bool_)

        # Batch buffers, allocated once and overwritten by every get_sample().
        self._s1_img_buf = np.zeros(batch_img_shape, dtype=dtype)
        self._s2_img_buf = np.zeros(batch_img_shape, dtype=dtype)
        # Same integer dtype as the storage so sampled actions stay usable
        # as indices instead of being silently converted to floats.
        self._a_buf = np.zeros(batch_size, dtype=self._a.dtype)
        self._r_buf = np.zeros(batch_size, dtype=dtype)
        self.terminal_buf = np.zeros(batch_size, dtype=np.bool_)

        self._misc = misc_len > 0
        if self._misc:
            self._s1_misc = np.zeros((capacity, misc_len), dtype=dtype)
            self._s2_misc = np.zeros((capacity, misc_len), dtype=dtype)
            self._s1_misc_buf = np.zeros((batch_size, misc_len), dtype=dtype)
            self._s2_misc_buf = np.zeros((batch_size, misc_len), dtype=dtype)
        else:
            self._s1_misc = None
            self._s2_misc = None
            self._s1_misc_buf = None
            self._s2_misc_buf = None

        self.capacity = capacity
        self.size = 0
        # Slot the next transition is written to; once the memory is full
        # this is also the slot holding the oldest transition.
        self._oldest_index = 0
        self._batch_size = batch_size

        # Built once: get_sample() always returns this same dict, whose
        # values are the batch buffers above.
        self._ret_dict = {
            "s1_img": self._s1_img_buf,
            "s2_img": self._s2_img_buf,
            "a": self._a_buf,
            "r": self._r_buf,
            "terminal": self.terminal_buf,
        }
        if self._misc:
            self._ret_dict["s1_misc"] = self._s1_misc_buf
            self._ret_dict["s2_misc"] = self._s2_misc_buf

    def add_transition(self, s1, a, s2, r, terminal):
        """Store one transition, overwriting the oldest one when full.

        Args:
            s1: State before the action (image at index 0, misc at index 1).
            a: Action, stored as ``np.int32``.
            s2: State after the action; not read when ``terminal`` is true,
                so it may be ``None`` in that case.
            r: Reward.
            terminal: Whether the transition ended the episode.
        """
        if self.size < self.capacity:
            self.size += 1

        slot = self._oldest_index

        self._s1_img[slot] = s1[0]
        if terminal:
            # Clear the slot so a terminal transition never exposes the
            # successor state left behind by the transition it overwrote.
            self._s2_img[slot] = 0
        else:
            self._s2_img[slot] = s2[0]

        if self._misc:
            self._s1_misc[slot] = s1[1]
            if terminal:
                self._s2_misc[slot] = 0
            else:
                self._s2_misc[slot] = s2[1]

        self._a[slot] = a
        self._r[slot] = r
        self._terminal[slot] = terminal

        self._oldest_index = (slot + 1) % self.capacity

    def get_sample(self):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Returns:
            dict: Keys ``"s1_img"``, ``"s2_img"``, ``"a"``, ``"r"`` and
            ``"terminal"``, plus ``"s1_misc"`` and ``"s2_misc"`` when misc
            vectors are enabled. The dict and its arrays are reused by every
            call, so copy anything that must survive the next call.

        Raises:
            ValueError: If fewer than ``batch_size`` transitions are stored.
        """
        if self._batch_size > self.size:
            raise ValueError("Replay memory doesn't contain " + str(self._batch_size) + " entries.")

        # Sampling without replacement over the filled part of the buffer.
        indexes = random.sample(range(self.size), self._batch_size)

        self._s1_img_buf[:] = self._s1_img[indexes]
        self._s2_img_buf[:] = self._s2_img[indexes]
        if self._misc:
            self._s1_misc_buf[:] = self._s1_misc[indexes]
            self._s2_misc_buf[:] = self._s2_misc[indexes]
        self._a_buf[:] = self._a[indexes]
        self._r_buf[:] = self._r[indexes]
        self.terminal_buf[:] = self._terminal[indexes]
        return self._ret_dict