diff --git a/wrappers.py b/wrappers.py index 6862705..41503ce 100644 --- a/wrappers.py +++ b/wrappers.py @@ -176,6 +176,7 @@ def step(self, action): def reset(self): obs = self._env.reset() + obs = {k: self._convert(v) for k, v in obs.items()} transition = obs.copy() transition['action'] = np.zeros(self._env.action_space.shape) transition['reward'] = 0.0 @@ -343,7 +344,7 @@ def __getattr__(self, name): def observation_space(self): spaces = self._env.observation_space.spaces assert 'reward' not in spaces - spaces['reward'] = gym.spaces.Box(-np.inf, np.inf, dtype=np.float32) + spaces['reward'] = gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32) return gym.spaces.Dict(spaces) def step(self, action):