Hello,

Given that `forward()` returns a tuple:
`return out.view(batch_size, num_tau, self.num_actions), taus`

should we use `.max(1)` instead of `.max(2)`?

Currently it is:
`Q_targets_next = Q_targets_next.detach().max(2)[0].unsqueeze(1)  # (batch_size, 1, N)`

Maybe it should be:
`Q_targets_next = Q_targets_next.detach().max(1)[0].unsqueeze(1)  # (batch_size, 1, num_actions)`

In other words, should the maximum be taken over the tau dimension rather than across the actions? Sorry if I misunderstood the process.
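As a quick standalone shape check (a toy tensor, not code from this repo), here is what each `dim` reduces over, in case it helps frame the question:

```python
import torch

batch_size, num_tau, num_actions = 2, 8, 4

# Toy stand-in for the target network output: (batch_size, num_tau, num_actions)
q_quantiles = torch.randn(batch_size, num_tau, num_actions)

# .max(2): for every tau sample, the best value across actions
max_over_actions = q_quantiles.max(2)[0]    # (batch_size, num_tau)
print(max_over_actions.unsqueeze(1).shape)  # torch.Size([2, 1, 8])

# .max(1): for every action, the best value across tau samples
max_over_taus = q_quantiles.max(1)[0]       # (batch_size, num_actions)
print(max_over_taus.unsqueeze(1).shape)     # torch.Size([2, 1, 4])
```

The full learn() method for context: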
```python
def learn(self, experiences):
    """Update value parameters using given batch of experience tuples.

    Params
    ======
        experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
    """
    self.optimizer.zero_grad()
    states, actions, rewards, next_states, dones = experiences

    # Get max predicted Q values (for next states) from target model
    Q_targets_next, _ = self.qnetwork_target(next_states)  # (batch_size, num_tau, num_actions)
    Q_targets_next = Q_targets_next.detach().max(2)[0].unsqueeze(1)  # (batch_size, 1, N) <-------- HERE

    # Compute Q targets for current states
    Q_targets = rewards.unsqueeze(-1) + (self.GAMMA**self.n_step * Q_targets_next * (1. - dones.unsqueeze(-1)))

    # Get expected Q values from local model
    Q_expected, taus = self.qnetwork_local(states)
    Q_expected = Q_expected.gather(2, actions.unsqueeze(-1).expand(self.BATCH_SIZE, 8, 1))  # 8 = num_tau

    # Quantile Huber loss
    td_error = Q_targets - Q_expected
    assert td_error.shape == (self.BATCH_SIZE, 8, 8), "wrong td error shape"
    huber_l = calculate_huber_loss(td_error, 1.0)
    quantil_l = abs(taus - (td_error.detach() < 0).float()) * huber_l / 1.0

    loss = quantil_l.sum(dim=1).mean(dim=1)  # keepdim=True if PER weights get multiplied
    loss = loss.mean()

    # Minimize the loss
    loss.backward()
    # clip_grad_norm_(self.qnetwork_local.parameters(), 1)
    self.optimizer.step()

    # ------------------- update target network ------------------- #
    self.soft_update(self.qnetwork_local, self.qnetwork_target)
    return loss.detach().cpu().numpy()
```
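For completeness, `calculate_huber_loss` isn't shown above; I assume it is the standard elementwise Huber loss with threshold k (quadratic below k, linear above), roughly like this sketch (my assumption, not necessarily the repo's exact code):

```python
import torch

def calculate_huber_loss(td_errors, k=1.0):
    """Elementwise Huber loss: 0.5*x^2 for |x| <= k, k*(|x| - 0.5*k) otherwise.
    Assumed form of the helper used above, not copied from the repo."""
    abs_err = td_errors.abs()
    return torch.where(abs_err <= k,
                       0.5 * td_errors.pow(2),
                       k * (abs_err - 0.5 * k))
```

With k = 1.0 this matches the `huber_l / 1.0` division in the quantile loss above.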