-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbuffers.py
97 lines (76 loc) · 3.79 KB
/
buffers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import random
import pickle
import numpy as np
from collections import deque, defaultdict
class MultiEnvReplayBuffer:
    """Replay buffer holding one bounded FIFO of transitions per environment.

    Layout::

        buffers: dict[env_id] -> deque(maxlen=buffer_size_per_env)
        each deque element is one transition tuple:
            (state, action, reward, next_state, done)

    ``elem_for_buffer[env_id]`` mirrors ``len(buffers[env_id])`` and is kept
    for callers that read it directly.

    Persistence: ``write_buffer`` stores every deque in its own pickle file
    (one file per environment, named ``buffer_env<id>.pkl`` with the id
    zero-padded to at least two digits); ``read_buffers`` loads them back.
    """

    # Naming scheme shared by write_buffer and read_buffers.
    _FILE_PREFIX = 'buffer_env'
    _FILE_SUFFIX = '.pkl'

    def __init__(self, buffer_size_per_env, **kwargs):
        """Create an empty buffer, optionally preloading it from disk.

        Args:
            buffer_size_per_env: maximum number of transitions kept for each
                environment; older transitions are dropped first.
            **kwargs: if ``preload`` is truthy, ``path_preload`` must name a
                directory previously filled by ``write_buffer``.

        Raises:
            KeyError: if ``preload`` is truthy but ``path_preload`` is absent.
        """
        # Both attribute names are kept: external code may read either one.
        self.size_per_env = buffer_size_per_env
        self.buffer_size_per_env = buffer_size_per_env
        self._initialize_empty_buffer()
        self.elem_for_buffer = defaultdict(int)
        # .get() tolerates a missing 'preload' key, unlike kwargs['preload'].
        if kwargs.get('preload'):
            self.read_buffers(kwargs['path_preload'])

    def add(self, state, action, reward, next_state, done, env_id):
        """Append one transition to the buffer of environment ``env_id``."""
        env_buffer = self.buffers[env_id]
        # The deque evicts its oldest element on its own once maxlen is hit.
        env_buffer.append((state, action, reward, next_state, done))
        self.elem_for_buffer[env_id] = len(env_buffer)

    def sample_env(self, env_id, batch_size):
        """Sample ``batch_size`` distinct transitions from one environment.

        Raises:
            ValueError: if the environment holds fewer than ``batch_size``
                transitions (raised by ``random.sample``).
        """
        # .get() avoids registering an empty deque for an unknown env_id.
        population = self.buffers.get(env_id, ())
        transitions = random.sample(population, batch_size)
        return self._transpose_transitions(transitions)

    def sample_all_envs(self, batch_size):
        """Sample ``batch_size`` transitions across all environments.

        Each draw picks a non-empty environment uniformly at random, then a
        transition uniformly at random from it (with replacement).

        Raises:
            ValueError: if ``batch_size`` is positive and no transition has
                been stored yet.
        """
        # Empty deques cannot be sampled from, so they are excluded up front.
        candidates = [buf for buf in self.buffers.values() if buf]
        if batch_size > 0 and not candidates:
            raise ValueError('cannot sample from an empty replay buffer')
        batch = [
            random.choice(random.choice(candidates))
            for _ in range(batch_size)
        ]
        return self._transpose_transitions(batch)

    def _transpose_transitions(self, transitions):
        """Turn a list of transition tuples into per-field numpy arrays.

        Returns:
            ``(states, actions, rewards, next_states, dones)``; all but
            ``dones`` are float64, and ``rewards`` is reshaped to a column
            of shape ``(len(transitions), 1)``.
        """
        states = np.array([t[0] for t in transitions], dtype=np.float64)
        actions = np.array([t[1] for t in transitions], dtype=np.float64)
        rewards = np.array(
            [t[2] for t in transitions], dtype=np.float64
        ).reshape((-1, 1))
        next_states = np.array([t[3] for t in transitions], dtype=np.float64)
        dones = np.array([t[4] for t in transitions])
        return states, actions, rewards, next_states, dones

    def _initialize_empty_buffer(self):
        """Replace ``self.buffers`` with a fresh, empty mapping of deques."""
        self.buffers = defaultdict(
            lambda: deque(maxlen=self.buffer_size_per_env)
        )

    def write_buffer(self, path):
        """Pickle every per-environment deque into directory ``path``.

        Only the deques are saved (not the enclosing defaultdict), one file
        per environment.
        """
        if path:
            os.makedirs(path, exist_ok=True)
        for env_id, env_buffer in self.buffers.items():
            # Zero-pad to at least two digits: 3 -> '03', 12 -> '12'.
            str_env_id = str(env_id).zfill(2)
            file_name = f'{self._FILE_PREFIX}{str_env_id}{self._FILE_SUFFIX}'
            file_path = os.path.join(path, file_name)
            with open(file_path, 'wb') as file:
                pickle.dump(env_buffer, file)

    def read_buffers(self, path, from_scratch=False):
        """Load the deques stored in ``path`` by ``write_buffer``.

        Loaded transitions are appended to whatever each environment already
        holds, subject to the per-environment size limit.

        Args:
            path: directory containing ``buffer_env<id>.pkl`` files; any
                other file in the directory is ignored.
            from_scratch: if true, discard the current content first.
        """
        if from_scratch:
            self._initialize_empty_buffer()
            # Counters must not outlive the buffers they describe.
            self.elem_for_buffer.clear()
        files = os.listdir(path)
        print(f'old buffer to load: name: {files}')
        prefix, suffix = self._FILE_PREFIX, self._FILE_SUFFIX
        for file_name in sorted(files):
            if not (file_name.startswith(prefix) and file_name.endswith(suffix)):
                continue
            # Everything between prefix and suffix is the id, so ids of any
            # width (e.g. 123) round-trip correctly.
            key = file_name[len(prefix):-len(suffix)]
            if not key.isdigit():
                continue
            env_id = int(key)
            buffer_path = os.path.join(path, file_name)
            print(f'processing buffer path : {buffer_path}')
            with open(buffer_path, 'rb') as local_buff:
                # NOTE: pickle can execute arbitrary code while loading;
                # only point this at files produced by write_buffer.
                stored = pickle.load(local_buff)
            env_buffer = self.buffers[env_id]
            env_buffer.extend(stored)
            self.elem_for_buffer[env_id] = len(env_buffer)