diff --git a/dopamine/replay_memory/circular_replay_buffer.py b/dopamine/replay_memory/circular_replay_buffer.py index 7170434b..a992b887 100644 --- a/dopamine/replay_memory/circular_replay_buffer.py +++ b/dopamine/replay_memory/circular_replay_buffer.py @@ -265,17 +265,40 @@ def _add(self, *args): Args: *args: All the elements in a transition. """ - cursor = self.cursor() + self._check_args_length(*args) + transition = {e.name: args[idx] + for idx, e in enumerate(self.get_add_args_signature())} + self._add_transition(transition) + + def _add_transition(self, transition): + """Internal add method to add transition dictionary to storage arrays. - arg_names = [e.name for e in self.get_add_args_signature()] - for arg_name, arg in zip(arg_names, args): - self._store[arg_name][cursor] = arg + Args: + transition: The dictionary of names and values of the transition + to add to the storage. + """ + cursor = self.cursor() + for arg_name in transition: + self._store[arg_name][cursor] = transition[arg_name] self.add_count += 1 self.invalid_range = invalid_range( self.cursor(), self._replay_capacity, self._stack_size, self._update_horizon) + def _check_args_length(self, *args): + """Check if args passed to the add method have the same length as storage. + + Args: + *args: Args for elements used in storage. + + Raises: + ValueError: If args have wrong length. + """ + if len(args) != len(self.get_add_args_signature()): + raise ValueError('Add expects {} elements, received {}'.format( + len(self.get_add_args_signature()), len(args))) + def _check_add_types(self, *args): """Checks if args passed to the add method match those of the storage. @@ -285,9 +308,7 @@ def _check_add_types(self, *args): Raises: ValueError: If args have wrong shape or dtype. """ - if len(args) != len(self.get_add_args_signature()): - raise ValueError('Add expects {} elements, received {}'.format( - len(self.get_add_args_signature()), len(args))) + self._check_args_length(*args) for arg_element, store_element in zip(args, self.get_add_args_signature()): if isinstance(arg_element, np.ndarray): arg_shape = arg_element.shape diff --git a/dopamine/replay_memory/prioritized_replay_buffer.py b/dopamine/replay_memory/prioritized_replay_buffer.py index 9ed778ba..bfd094c1 100644 --- a/dopamine/replay_memory/prioritized_replay_buffer.py +++ b/dopamine/replay_memory/prioritized_replay_buffer.py @@ -124,20 +124,20 @@ def _add(self, *args): Args: *args: All the elements in a transition. """ + self._check_args_length(*args) + # Use Schaul et al.'s (2015) scheme of setting the priority of new elements # to the maximum priority so far. - parent_add_args = [] - # Picks out 'priority' from arguments and passes the other arguments to the - # parent method. + # Picks out 'priority' from arguments and adds it to the sum_tree. + transition = {} for i, element in enumerate(self.get_add_args_signature()): if element.name == 'priority': priority = args[i] else: - parent_add_args.append(args[i]) + transition[element.name] = args[i] self.sum_tree.set(self.cursor(), priority) - - super(OutOfGraphPrioritizedReplayBuffer, self)._add(*parent_add_args) + super(OutOfGraphPrioritizedReplayBuffer, self)._add_transition(transition) def sample_index_batch(self, batch_size): """Returns a batch of valid indices sampled as in Schaul et al. (2015). diff --git a/tests/dopamine/replay_memory/prioritized_replay_buffer_test.py b/tests/dopamine/replay_memory/prioritized_replay_buffer_test.py index db6c0e59..6049edb8 100644 --- a/tests/dopamine/replay_memory/prioritized_replay_buffer_test.py +++ b/tests/dopamine/replay_memory/prioritized_replay_buffer_test.py @@ -60,6 +60,20 @@ def add_blank(self, memory, action=0, reward=0.0, terminal=0, priority=1.0): index = (memory.cursor() - 1) % REPLAY_CAPACITY return index + def testAddWithAndWithoutPriority(self): + memory = self.create_default_memory() + self.assertEqual(memory.cursor(), 0) + zeros = np.zeros(SCREEN_SIZE) + + self.add_blank(memory) + self.assertEqual(memory.cursor(), STACK_SIZE) + self.assertEqual(memory.add_count, STACK_SIZE) + + # Check that the prioritized replay buffer expects an additional argument + # for priority. + with self.assertRaisesRegexp(ValueError, 'Add expects'): + memory.add(zeros, 0, 0, 0) + def testDummyScreensAddedToNewMemory(self): memory = self.create_default_memory() index = self.add_blank(memory)