Skip to content

Commit

Permalink
Fix bug whereby batches of size larger than 1000 would never get
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 242851617
  • Loading branch information
Marc G. Bellemare authored and psc-g committed Apr 11, 2019
1 parent f6dc339 commit 2435916
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions dopamine/replay_memory/circular_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@

# This constant determines how many iterations a checkpoint is kept for.
CHECKPOINT_DURATION = 4
MAX_SAMPLE_ATTEMPTS = 1000


def invalid_range(cursor, replay_capacity, stack_size, update_horizon):
Expand Down Expand Up @@ -103,7 +102,7 @@ def __init__(self,
batch_size,
update_horizon=1,
gamma=0.99,
max_sample_attempts=MAX_SAMPLE_ATTEMPTS,
max_sample_attempts=1000,
extra_storage_types=None,
observation_dtype=np.uint8,
action_shape=(),
Expand Down Expand Up @@ -438,10 +437,11 @@ def sample_index_batch(self, batch_size):
attempt_count = 0
while (len(indices) < batch_size and
attempt_count < self._max_sample_attempts):
attempt_count += 1
index = np.random.randint(min_id, max_id) % self._replay_capacity
if self.is_valid_transition(index):
indices.append(index)
else:
attempt_count += 1
if len(indices) != batch_size:
raise RuntimeError(
'Max sample attempts: Tried {} times but only sampled {}'
Expand Down Expand Up @@ -684,7 +684,7 @@ def __init__(self,
update_horizon=1,
gamma=0.99,
wrapped_memory=None,
max_sample_attempts=MAX_SAMPLE_ATTEMPTS,
max_sample_attempts=1000,
extra_storage_types=None,
observation_dtype=np.uint8,
action_shape=(),
Expand Down
4 changes: 2 additions & 2 deletions dopamine/replay_memory/prioritized_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self,
batch_size,
update_horizon=1,
gamma=0.99,
max_sample_attempts=circular_replay_buffer.MAX_SAMPLE_ATTEMPTS,
max_sample_attempts=1000,
extra_storage_types=None,
observation_dtype=np.uint8,
action_shape=(),
Expand Down Expand Up @@ -271,7 +271,7 @@ def __init__(self,
batch_size=32,
update_horizon=1,
gamma=0.99,
max_sample_attempts=circular_replay_buffer.MAX_SAMPLE_ATTEMPTS,
max_sample_attempts=1000,
extra_storage_types=None,
observation_dtype=np.uint8,
action_shape=(),
Expand Down

0 comments on commit 2435916

Please sign in to comment.