diff --git a/dopamine/discrete_domains/atari_lib.py b/dopamine/discrete_domains/atari_lib.py index fff476e1..6d0467e2 100644 --- a/dopamine/discrete_domains/atari_lib.py +++ b/dopamine/discrete_domains/atari_lib.py @@ -51,8 +51,14 @@ class has two main functions: `.__init__` and `.call`. When we create our import tensorflow.compat.v1 as tf import cv2 -from tensorflow.contrib import layers as contrib_layers -from tensorflow.contrib import slim as contrib_slim +from tensorflow.compat.v1 import layers as contrib_layers + +# Allow failure on this import (not in tf2). This means atari won't be +# available but other domains will. +try: + from tensorflow.contrib import slim as contrib_slim +except: + pass NATURE_DQN_OBSERVATION_SHAPE = (84, 84) # Size of downscaled Atari 2600 frame. diff --git a/dopamine/replay_memory/circular_replay_buffer.py b/dopamine/replay_memory/circular_replay_buffer.py index 1a5020fa..096fffd7 100644 --- a/dopamine/replay_memory/circular_replay_buffer.py +++ b/dopamine/replay_memory/circular_replay_buffer.py @@ -34,7 +34,7 @@ import tensorflow.compat.v1 as tf import gin.tf -from tensorflow.contrib import staging as contrib_staging +from tensorflow.python.ops import data_flow_ops # Defines a type describing part of the tuple returned by the replay # memory. Each element of the tuple is a tensor of shape [batch, ...] where @@ -855,7 +855,7 @@ def _set_up_staging(self, transition): transition_type = self.memory.get_transition_elements() # Create the staging area in CPU. - prefetch_area = contrib_staging.StagingArea( + prefetch_area = data_flow_ops.StagingArea( [shape_with_type.type for shape_with_type in transition_type]) # Store prefetch op for tests, but keep it private -- users should not be