You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
@jupitermarketingagency OH MY GOOOOOOODDDD! Thank you for this — I've been fighting this stupid program for THREE DAYS! This guy really needs to revisit this code; I've had to debug basically all of it, with a few rare exceptions. And this is only Chapter 4!!!!
You are a lifesaver! If you conjure up any more fixes, please post them — I will be eternally grateful. :-)
@dkinneyBU Glad to hear that was helpful. Yes, we agree that he should revisit this code. Across the RL courses we've tried so far, we've seen this happen over and over again, because the books are more than two years old and the libraries they rely on have changed since then (e.g. gym → gymnasium). So we've been trying to focus only on recently published books.
#!/usr/bin/env python3
import gymnasium as gym
from collections import namedtuple
import numpy as np
from tensorboardX import SummaryWriter
import torch
import torch.nn as nn
import torch.optim as optim
# Module-level hyper-parameters. None of them is referenced in the visible
# part of the script; presumably they are consumed by the training loop.
HIDDEN_SIZE: int = 128  # presumably the hidden-layer width passed to Net
BATCH_SIZE: int = 16  # presumably the episodes-per-batch for iterate_batches
PERCENTILE: int = 70  # presumably the reward cut-off given to filter_batch
class Net(nn.Module):
    """Two-layer MLP: Linear -> ReLU -> Linear.

    Args:
        obs_size: length of the (flat) input vector.
        hidden_size: width of the hidden layer.
        n_actions: width of the output layer (one score per action).
    """

    def __init__(self, obs_size, hidden_size, n_actions):
        # nn.Module's own __init__ must run before any submodule is assigned,
        # otherwise the parameter/submodule registry does not exist yet.
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        )

    def forward(self, x):
        """Return raw (unnormalised) scores for ``x``; no softmax is applied here."""
        return self.net(x)
# One finished episode: total undiscounted reward and the list of EpisodeStep taken.
Episode = namedtuple('Episode', field_names=['reward', 'steps'])
# One (observation, action) pair recorded while playing an episode.
EpisodeStep = namedtuple('EpisodeStep', field_names=['observation', 'action'])


def iterate_batches(env, net, batch_size):
    """Play episodes forever and yield them in lists of ``batch_size``.

    At every step the action is sampled from the softmax of ``net``'s output
    for the current observation. An episode ends when ``step`` reports either
    ``terminated`` or ``truncated`` (gymnasium's 5-tuple API); the environment
    is then reset and a fresh episode starts.

    Args:
        env: environment with ``reset() -> (obs, info)`` and
            ``step(action) -> (obs, reward, terminated, truncated, info)``.
        net: callable mapping a ``(1, obs_size)`` float32 tensor to a
            ``(1, n_actions)`` tensor of scores (observations are assumed to be
            flat vectors).
        batch_size: number of complete episodes per yielded batch.

    Yields:
        A new list of ``batch_size`` ``Episode`` records. The generator never
        finishes on its own; the caller decides when to stop.
    """
    batch = []
    episode_reward = 0.0
    episode_steps = []
    obs, _ = env.reset()
    env.render()
    sm = nn.Softmax(dim=1)
    while True:
        # Convert through NumPy: torch is very slow building a tensor from a
        # Python list of ndarrays.
        obs_v = torch.as_tensor(np.asarray([obs], dtype=np.float32))
        # Sampling only, no gradients needed.
        with torch.no_grad():
            act_probs = sm(net(obs_v)).numpy()[0]
        action = np.random.choice(len(act_probs), p=act_probs)
        next_obs, reward, terminated, truncated, _ = env.step(action)
        episode_reward += reward
        episode_steps.append(EpisodeStep(observation=obs, action=action))
        # A time-limit truncation also ends the episode; without this check
        # a truncated episode would keep running and never be recorded.
        if terminated or truncated:
            batch.append(Episode(reward=episode_reward, steps=episode_steps))
            episode_reward = 0.0
            episode_steps = []
            next_obs, _ = env.reset()
            if len(batch) == batch_size:
                yield batch
                batch = []
        obs = next_obs
def filter_batch(batch, percentile):
    """Keep the steps of episodes whose total reward reaches a percentile cut-off.

    Args:
        batch: non-empty sequence of episodes. Each episode exposes ``reward``
            and ``steps``, and each step exposes ``observation`` and ``action``.
        percentile: percentile in [0, 100] of the batch rewards used as the
            cut-off.

    Returns:
        ``(obs_v, acts_v, reward_bound, reward_mean)``:
        ``obs_v`` is a float32 tensor of the kept observations (one row per
        step), ``acts_v`` an int64 tensor of the matching actions,
        ``reward_bound`` the percentile cut-off, and ``reward_mean`` the mean
        reward over the whole batch.

    Raises:
        ValueError: if ``batch`` is empty (the percentile is undefined).
    """
    if not batch:
        raise ValueError("filter_batch() needs at least one episode")
    rewards = [episode.reward for episode in batch]
    reward_bound = np.percentile(rewards, percentile)
    reward_mean = float(np.mean(rewards))

    train_obs = []
    train_act = []
    for episode in batch:
        # Drop only episodes strictly below the cut-off. Ties are kept, so the
        # best episode always survives (the percentile never exceeds the max).
        if episode.reward < reward_bound:
            continue
        train_obs.extend(step.observation for step in episode.steps)
        train_act.extend(step.action for step in episode.steps)

    # Stack through NumPy: torch is slow building a tensor from a list of ndarrays.
    obs_v = torch.as_tensor(np.asarray(train_obs, dtype=np.float32))
    acts_v = torch.as_tensor(np.asarray(train_act, dtype=np.int64))
    return obs_v, acts_v, reward_bound, reward_mean
if __name__ == "__main__":
    # Entry point when run as a script. The guard needs the dunder names;
    # a bare `name` is undefined and raises NameError at import time.
    env = gym.make("CartPole-v1", render_mode="human")
    # gym.wrappers.Monitor no longer exists in gymnasium. To record episodes,
    # wrap with gymnasium.wrappers.RecordVideo and create the env with
    # render_mode="rgb_array" instead.
    obs_size = env.observation_space.shape[0]
    n_actions = env.action_space.n
The text was updated successfully, but these errors were encountered: