# a2c_devel.py
# Forked from qfettes/DeepRL-Tutorials.
import gym
gym.logger.set_level(40)
import numpy as np
import torch
import torch.nn.functional as F
from IPython.display import clear_output
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
plt.switch_backend('agg')
from timeit import default_timer as timer
from datetime import timedelta
import os
import glob
from utils.wrappers import make_env_a2c_atari
from utils.plot import visdom_plot
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from utils.hyperparameters import Config
from agents.PPO import Model
# Visdom plotting toggle and the port used to connect to the Visdom server.
use_vis = True
port = 8097

# Directory handed to make_env_a2c_atari and visdom_plot further down;
# presumably where the environment monitors write their *.monitor.csv files.
log_dir = "/tmp/gym/"

# Make sure the directory exists, then clear monitor logs left over from a
# previous run. exist_ok=True tolerates only the "already exists" case, so
# real failures (permission denied, path is a regular file) surface here
# instead of being silently swallowed. On a freshly created directory the
# glob is empty and the loop is a no-op.
os.makedirs(log_dir, exist_ok=True)
for stale_log in glob.glob(os.path.join(log_dir, '*.monitor.csv')):
    os.remove(stale_log)
# Hyperparameter container that is passed to the agent Model below.
config = Config()

# PPO-specific settings (interpreted by agents.PPO.Model).
config.ppo_epoch = 3
config.num_mini_batch = 32
config.ppo_clip_param = 0.1

# Rollout collection: number of parallel environments and the number of
# steps gathered from each one before every update, plus GAE settings.
config.num_agents = 8
config.rollout = 128
config.USE_GAE = True
config.gae_tau = 0.95

# General optimisation settings.
config.GAMMA = 0.99
config.LR = 7e-4
config.entropy_loss_weight = 0.01
config.value_loss_weight = 1.0
config.grad_norm_max = 0.5

# MAX_FRAMES is used as the number of update iterations of the training
# loop: a budget of 1e7 environment steps divided by the steps consumed per
# iteration (num_agents * rollout).
config.MAX_FRAMES = int(1e7 / config.num_agents / config.rollout)
if __name__ == '__main__':
    # Training entry point: builds the vectorised Pong environments and the
    # PPO model, then repeatedly collects `config.rollout` steps from every
    # environment and runs one model update on the collected rollout.
    seed = 1

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    torch.set_num_threads(1)

    if use_vis:
        # Imported here so that visdom is only needed when plotting is on.
        from visdom import Visdom
        viz = Visdom(port=port)
        win = None

    env_id = "PongNoFrameskip-v4"
    envs = [make_env_a2c_atari(env_id, seed, i, log_dir) for i in range(config.num_agents)]
    envs = SubprocVecEnv(envs) if config.num_agents > 1 else DummyVecEnv(envs)

    # The vectorised environments are live from here on; the finally clause
    # guarantees they are closed even if training raises, so SubprocVecEnv
    # worker processes are not left behind.
    try:
        # The model consumes a stack of 4 observations joined along the
        # first observation dimension (presumably channels — TODO confirm).
        obs_shape = envs.observation_space.shape
        obs_shape = (obs_shape[0] * 4, *obs_shape[1:])

        model = Model(env=envs, config=config)

        current_obs = torch.zeros(config.num_agents, *obs_shape,
                                  device=config.device, dtype=torch.float)

        def update_current_obs(obs):
            """Push the newest observation batch into the stacked buffer.

            Shifts `current_obs` along dim 1 by one observation's worth of
            entries (dropping the oldest) and writes `obs` into the freed
            trailing slots. Mutates `current_obs` in place.

            Args:
                obs: numpy array returned by `envs.reset()` / `envs.step()`.
            """
            shape_dim0 = envs.observation_space.shape[0]
            obs = torch.from_numpy(obs.astype(np.float32)).to(config.device)
            current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:]
            current_obs[:, -shape_dim0:] = obs

        obs = envs.reset()
        update_current_obs(obs)

        model.rollouts.observations[0].copy_(current_obs)

        # Running return of the episode in progress, and the return of the
        # most recently finished episode, per environment. np.float64 is the
        # dtype the removed `np.float` alias used to resolve to.
        episode_rewards = np.zeros(config.num_agents, dtype=np.float64)
        final_rewards = np.zeros(config.num_agents, dtype=np.float64)

        start = timer()

        # Print training statistics every `print_threshold` updates.
        print_threshold = 10

        for frame_idx in range(1, config.MAX_FRAMES + 1):
            for step in range(config.rollout):
                with torch.no_grad():
                    values, actions, action_log_prob = model.get_action(
                        model.rollouts.observations[step])
                cpu_actions = actions.view(-1).cpu().numpy()

                obs, reward, done, _ = envs.step(cpu_actions)

                # masks is 0.0 where an episode just ended and 1.0 elsewhere.
                # Finished episodes move their total into final_rewards and
                # restart their running total from zero.
                episode_rewards += reward
                masks = 1. - done.astype(np.float32)
                final_rewards *= masks
                final_rewards += (1. - masks) * episode_rewards
                episode_rewards *= masks

                rewards = torch.from_numpy(reward.astype(np.float32)).view(-1, 1).to(config.device)
                masks = torch.from_numpy(masks).to(config.device).view(-1, 1)

                # Zero the stacked history of environments whose episode
                # ended before pushing in the new observation.
                current_obs *= masks.view(-1, 1, 1, 1)
                update_current_obs(obs)

                model.rollouts.insert(current_obs, actions.view(-1, 1), action_log_prob,
                                      values, rewards, masks)

            with torch.no_grad():
                next_value = model.get_values(model.rollouts.observations[-1])

            model.rollouts.compute_returns(next_value, config.GAMMA)

            value_loss, action_loss, dist_entropy = model.update(model.rollouts)

            model.rollouts.after_update()

            if frame_idx % print_threshold == 0:
                end = timer()
                # frame_idx is 1-based, so exactly frame_idx rollouts have
                # been collected at this point (no "+ 1").
                total_num_steps = frame_idx * config.num_agents * config.rollout
                print("Updates {}, num timesteps {}, FPS {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}".
                      format(frame_idx, total_num_steps,
                             int(total_num_steps / (end - start)),
                             np.mean(final_rewards),
                             np.median(final_rewards),
                             np.min(final_rewards),
                             np.max(final_rewards), dist_entropy,
                             value_loss, action_loss))

            if use_vis and frame_idx % 100 == 0:
                try:
                    # Sometimes monitor doesn't properly flush the outputs;
                    # plotting is best-effort, so a failed read is skipped.
                    win = visdom_plot(viz, win, log_dir, env_id,
                                      'a2c-Q', config.MAX_FRAMES * config.num_agents * config.rollout)
                except IOError:
                    pass

        model.save_w()
    finally:
        envs.close()