Skip to content

Commit

Permalink
re-implement self._add for OutOfGraphPrioritizedReplayBuffer and …
Browse files Browse the repository at this point in the history
…assert at runtime that the number of args passed to `_add` equals the length of the output of `self.get_add_args_signature`

PiperOrigin-RevId: 258364421
  • Loading branch information
btaba authored and psc-g committed Jul 22, 2019
1 parent f5f971f commit 85fa5c2
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 13 deletions.
35 changes: 28 additions & 7 deletions dopamine/replay_memory/circular_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,17 +265,40 @@ def _add(self, *args):
Args:
*args: All the elements in a transition.
"""
cursor = self.cursor()
self._check_args_length(*args)
transition = {e.name: args[idx]
for idx, e in enumerate(self.get_add_args_signature())}
self._add_transition(transition)

def _add_transition(self, transition):
    """Internal add method to add transition dictionary to storage arrays.

    Each value is written into the storage array registered under the same
    name, at the current cursor position. Afterwards the add counter is
    incremented and the invalid range is recomputed.

    Args:
        transition: The dictionary of names and values of the transition
            to add to the storage.
    """
    cursor = self.cursor()
    for arg_name, arg in transition.items():
        self._store[arg_name][cursor] = arg

    self.add_count += 1
    # self.cursor() is called again here, after add_count was incremented,
    # instead of reusing the local above. NOTE(review): presumably the cursor
    # is derived from add_count -- confirm before collapsing the two calls.
    self.invalid_range = invalid_range(
        self.cursor(), self._replay_capacity, self._stack_size,
        self._update_horizon)

def _check_args_length(self, *args):
    """Check if args passed to the add method have the same length as storage.

    Args:
        *args: Args for elements used in storage.

    Raises:
        ValueError: If args have wrong length.
    """
    # Query the signature once so the comparison and the error message are
    # guaranteed to report the same expected count.
    expected_length = len(self.get_add_args_signature())
    if len(args) != expected_length:
        raise ValueError('Add expects {} elements, received {}'.format(
            expected_length, len(args)))

def _check_add_types(self, *args):
"""Checks if args passed to the add method match those of the storage.
Expand All @@ -285,9 +308,7 @@ def _check_add_types(self, *args):
Raises:
ValueError: If args have wrong shape or dtype.
"""
if len(args) != len(self.get_add_args_signature()):
raise ValueError('Add expects {} elements, received {}'.format(
len(self.get_add_args_signature()), len(args)))
self._check_args_length(*args)
for arg_element, store_element in zip(args, self.get_add_args_signature()):
if isinstance(arg_element, np.ndarray):
arg_shape = arg_element.shape
Expand Down
12 changes: 6 additions & 6 deletions dopamine/replay_memory/prioritized_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,20 +124,20 @@ def _add(self, *args):
Args:
*args: All the elements in a transition.
"""
self._check_args_length(*args)

# Use Schaul et al.'s (2015) scheme of setting the priority of new elements
# to the maximum priority so far.
parent_add_args = []
# Picks out 'priority' from arguments and passes the other arguments to the
# parent method.
# Picks out 'priority' from arguments and adds it to the sum_tree.
transition = {}
for i, element in enumerate(self.get_add_args_signature()):
if element.name == 'priority':
priority = args[i]
else:
parent_add_args.append(args[i])
transition[element.name] = args[i]

self.sum_tree.set(self.cursor(), priority)

super(OutOfGraphPrioritizedReplayBuffer, self)._add(*parent_add_args)
super(OutOfGraphPrioritizedReplayBuffer, self)._add_transition(transition)

def sample_index_batch(self, batch_size):
"""Returns a batch of valid indices sampled as in Schaul et al. (2015).
Expand Down
14 changes: 14 additions & 0 deletions tests/dopamine/replay_memory/prioritized_replay_buffer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,20 @@ def add_blank(self, memory, action=0, reward=0.0, terminal=0, priority=1.0):
index = (memory.cursor() - 1) % REPLAY_CAPACITY
return index

def testAddWithAndWithoutPriority(self):
    """Checks that add() on a prioritized buffer requires a priority value."""
    memory = self.create_default_memory()
    self.assertEqual(memory.cursor(), 0)
    zeros = np.zeros(SCREEN_SIZE)

    self.add_blank(memory)
    self.assertEqual(memory.cursor(), STACK_SIZE)
    self.assertEqual(memory.add_count, STACK_SIZE)

    # Check that the prioritized replay buffer expects an additional argument
    # for priority.
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed in Python 3.12.
    with self.assertRaisesRegex(ValueError, 'Add expects'):
        memory.add(zeros, 0, 0, 0)

def testDummyScreensAddedToNewMemory(self):
memory = self.create_default_memory()
index = self.add_blank(memory)
Expand Down

0 comments on commit 85fa5c2

Please sign in to comment.