diff --git a/all/core/state.py b/all/core/state.py
index d556d312..29e785ae 100644
--- a/all/core/state.py
+++ b/all/core/state.py
@@ -33,19 +33,19 @@ class State(dict):
             The torch device on which component tensors are stored.
     """
 
-    def __init__(self, x, device='cpu', **kwargs):
+    def __init__(self, x, device="cpu", **kwargs):
         if not isinstance(x, dict):
-            x = {'observation': x}
+            x = {"observation": x}
         for k, v in kwargs.items():
             x[k] = v
-        if 'observation' not in x:
-            raise Exception('State must contain an observation')
-        if 'reward' not in x:
-            x['reward'] = 0.
-        if 'done' not in x:
-            x['done'] = False
-        if 'mask' not in x:
-            x['mask'] = 1. - x['done']
+        if "observation" not in x:
+            raise Exception("State must contain an observation")
+        if "reward" not in x:
+            x["reward"] = 0.0
+        if "done" not in x:
+            x["done"] = False
+        if "mask" not in x:
+            x["mask"] = 1.0 - x["done"]
         super().__init__(x)
         self._shape = ()
         self.device = device
@@ -71,17 +71,33 @@ def array(cls, list_of_states):
             v = list_of_states[0][key]
             try:
                 if isinstance(v, list) and len(v) > 0 and torch.is_tensor(v[0]):
-                    x[key] = torch.stack([torch.stack(state[key]) for state in list_of_states])
+                    x[key] = torch.stack(
+                        [torch.stack(state[key]) for state in list_of_states]
+                    )
                 elif torch.is_tensor(v):
                     x[key] = torch.stack([state[key] for state in list_of_states])
                 else:
-                    x[key] = torch.tensor([state[key] for state in list_of_states], device=device)
+                    x[key] = torch.tensor(
+                        [state[key] for state in list_of_states], device=device
+                    )
             except KeyError:
-                warnings.warn('KeyError while creating StateArray for key "{}", omitting.'.format(key))
+                warnings.warn(
+                    'KeyError while creating StateArray for key "{}", omitting.'.format(
+                        key
+                    )
+                )
             except ValueError:
-                warnings.warn('ValueError while creating StateArray for key "{}", omitting.'.format(key))
+                warnings.warn(
+                    'ValueError while creating StateArray for key "{}", omitting.'.format(
+                        key
+                    )
+                )
             except TypeError:
-                warnings.warn('TypeError while creating StateArray for key "{}", omitting.'.format(key))
+                warnings.warn(
+                    'TypeError while creating StateArray for key "{}", omitting.'.format(
+                        key
+                    )
+                )
         return StateArray(x, shape, device=device)
@@ -100,7 +116,9 @@ def apply(self, model, *keys):
         Returns:
             The output of the model.
         """
-        return self.apply_mask(self.as_output(model(*[self.as_input(key) for key in keys])))
+        return self.apply_mask(
+            self.as_output(model(*[self.as_input(key) for key in keys]))
+        )
 
     def as_input(self, key):
         """
@@ -158,7 +176,7 @@ def update(self, key, value):
         return self.__class__(x, device=self.device)
 
     @classmethod
-    def from_gym(cls, gym_output, device='cpu', dtype=np.float32):
+    def from_gym(cls, gym_output, device="cpu", dtype=np.float32):
         """
         Constructs a State object given the return value of an OpenAI gym reset()/step(action) call.
 
@@ -170,27 +188,35 @@ def from_gym(cls, gym_output, device='cpu', dtype=np.float32):
         Returns:
             A State object.
         """
-        if not isinstance(gym_output, tuple) and (len(gym_output) == 2 or len(gym_output) == 5):
-            raise TypeError(f"gym_output should be a tuple, either (observation, info) or (observation, reward, terminated, truncated, info). Recieved {gym_output}.")
-        # extract info from timestep
-        if len(gym_output) == 5:
+        if isinstance(gym_output, tuple) and len(gym_output) == 5:
+            # gymnasium step()
             observation, reward, terminated, truncated, info = gym_output
-        if len(gym_output) == 2:
+        elif isinstance(gym_output, tuple) and len(gym_output) == 4:
+            # legacy gym step()
+            observation, reward, done, info = gym_output
+            terminated = done
+            truncated = False
+        elif isinstance(gym_output, tuple) and len(gym_output) == 2:
+            # gymnasium reset()
             observation, info = gym_output
-            reward = 0.
+            reward = 0.0
             terminated = False
             truncated = False
+        else:
+            # legacy gym reset()
+            observation = gym_output
+            reward = 0.0
+            terminated = False
+            truncated = False
+            info = {}
         x = {
-            'observation': torch.from_numpy(
-                np.array(
-                    observation,
-                    dtype=dtype
-                ),
-            ).to(device),
-            'reward': float(reward),
-            'done': terminated or truncated,
-            'mask': 1. - terminated
+            "observation": torch.from_numpy(
+                np.array(observation, dtype=dtype),
+            ).to(device),
+            "reward": float(reward),
+            "done": terminated or truncated,
+            "mask": 1.0 - terminated,
         }
         info = info if info else {}
         for key in info:
@@ -211,22 +237,22 @@ def to(self, device):
 
     @property
     def observation(self):
         """A tensor containing the current observation."""
-        return self['observation']
+        return self["observation"]
 
     @property
     def reward(self):
         """A float representing the reward for the previous state/action pair."""
-        return self['reward']
+        return self["reward"]
 
     @property
     def done(self):
         """A boolean that is true if the state is a terminal state, and false otherwise."""
-        return self['done']
+        return self["done"]
 
     @property
     def mask(self):
         """A float that is 1. if the state is non-terminal, or 0. otherwise."""
-        return self['mask']
+        return self["mask"]
 
     @property
     def shape(self):
@@ -239,47 +265,49 @@ def __len__(self):
 
 class StateArray(State):
     """
-    An n-dimensional array of environment State objects.
+    An n-dimensional array of environment State objects.
 
-    Internally, all components of the states are represented as n-dimensional tensors.
-    This allows for batch-style processing and easy manipulation of states.
-    Usually, a StateArray should be constructed using the State.array() function.
+    Internally, all components of the states are represented as n-dimensional tensors.
+    This allows for batch-style processing and easy manipulation of states.
+    Usually, a StateArray should be constructed using the State.array() function.
 
-    Args:
-        x (dict):
-            A dictionary containing all state information.
-            Each value should be a tensor in which the first n-dimensions
-            match the shape of the StateArray.
-            The following keys are standard:
+    Args:
+        x (dict):
+            A dictionary containing all state information.
+            Each value should be a tensor in which the first n-dimensions
+            match the shape of the StateArray.
+            The following keys are standard:
 
-            observation (torch.tensor) (required):
-                A tensor representing the observations for each state
+            observation (torch.tensor) (required):
+                A tensor representing the observations for each state
 
-            reward (torch.FloatTensor) (optional):
-                A tensor representing rewards for the previous state/action pairs
+            reward (torch.FloatTensor) (optional):
+                A tensor representing rewards for the previous state/action pairs
 
-            done (torch.BoolTensors) (optional):
-                A tensor representing whether each state is terminal
+            done (torch.BoolTensors) (optional):
+                A tensor representing whether each state is terminal
 
-            mask (torch.FloatTensor) (optional):
-                A tensor representing the mask for each state.
-        device (string):
-            The torch device on which component tensors are stored.
+            mask (torch.FloatTensor) (optional):
+                A tensor representing the mask for each state.
+        device (string):
+            The torch device on which component tensors are stored.
     """
 
-    def __init__(self, x, shape, device='cpu', **kwargs):
+    def __init__(self, x, shape, device="cpu", **kwargs):
         if not isinstance(x, dict):
-            x = {'observation': x}
+            x = {"observation": x}
         for k, v in kwargs.items():
             x[k] = v
-        if 'observation' not in x:
-            raise Exception('StateArray must contain an observation')
-        if 'reward' not in x:
-            x['reward'] = torch.zeros(shape, device=device)
-        if 'done' not in x:
-            x['done'] = torch.tensor([False] * np.prod(shape), device=device).view(shape)
-        if 'mask' not in x:
-            x['mask'] = 1. - x['done'].float()
+        if "observation" not in x:
+            raise Exception("StateArray must contain an observation")
+        if "reward" not in x:
+            x["reward"] = torch.zeros(shape, device=device)
+        if "done" not in x:
+            x["done"] = torch.tensor([False] * np.prod(shape), device=device).view(
+                shape
+            )
+        if "mask" not in x:
+            x["mask"] = 1.0 - x["done"].float()
         super().__init__(x, device=device)
         self._shape = shape
@@ -305,7 +333,9 @@ def update(self, key, value):
 
     def as_input(self, key):
         value = self[key]
-        return value.view((np.prod(self.shape), *value.shape[len(self.shape):])).float()
+        return value.view(
+            (np.prod(self.shape), *value.shape[len(self.shape):])
+        ).float()
 
     def as_output(self, tensor):
         return tensor.view((*self.shape, *tensor.shape[1:]))
@@ -343,31 +373,33 @@ def view(self, shape):
 
     @property
     def observation(self):
-        return self['observation']
+        return self["observation"]
 
     @property
     def reward(self):
-        return self['reward']
+        return self["reward"]
 
     @property
     def done(self):
-        return self['done']
+        return self["done"]
 
     @property
     def mask(self):
-        return self['mask']
+        return self["mask"]
 
     def __getitem__(self, key):
         if isinstance(key, slice) or isinstance(key, int):
-            shape = self['mask'][key].shape
+            shape = self["mask"][key].shape
             if len(shape) == 0:
                 return State({k: v[key] for (k, v) in self.items()}, device=self.device)
-            return StateArray({k: v[key] for (k, v) in self.items()}, shape, device=self.device)
+            return StateArray(
+                {k: v[key] for (k, v) in self.items()}, shape, device=self.device
+            )
         if torch.is_tensor(key):
             # some things may get lost
             d = {}
-            shape = self['mask'][key].shape
-            for (k, v) in self.items():
+            shape = self["mask"][key].shape
+            for k, v in self.items():
                 try:
                     d[k] = v[key]
                 except KeyError:
@@ -389,7 +421,7 @@ def __len__(self):
 
     @classmethod
     def cat(cls, state_array_list, axis=0):
-        '''Concatenates along batch dimention'''
+        """Concatenates along batch dimension"""
         if len(state_array_list) == 0:
             raise ValueError("cat accepts a non-zero size list of StateArrays")
 
@@ -400,13 +432,15 @@ def cat(cls, state_array_list, axis=0):
             new_shape = tuple(new_shape)
         keys = list(state_array_list[0].keys())
         for key in keys:
-            d[key] = torch.cat([state_array[key] for state_array in state_array_list], axis=axis)
+            d[key] = torch.cat(
+                [state_array[key] for state_array in state_array_list], axis=axis
+            )
         return StateArray(d, new_shape, device=state_array_list[0].device)
 
     def batch_execute(self, minibatch_size, fn):
-        '''
+        """
         execute in batches to reduce memory consumption
-        '''
+        """
         data = self
         batch_size = self.shape[0]
         results = []
@@ -424,17 +458,17 @@ def batch_execute(self, minibatch_size, fn):
 
 
 class MultiagentState(State):
-    def __init__(self, x, device='cpu', **kwargs):
-        if 'agent' not in x:
-            raise Exception('MultiagentState must contain an agent ID')
+    def __init__(self, x, device="cpu", **kwargs):
+        if "agent" not in x:
+            raise Exception("MultiagentState must contain an agent ID")
         super().__init__(x, device=device, **kwargs)
 
     @property
     def agent(self):
-        return self['agent']
+        return self["agent"]
 
     @classmethod
-    def from_zoo(cls, agent, state, device='cpu', dtype=np.float32):
+    def from_zoo(cls, agent, state, device="cpu", dtype=np.float32):
         """
         Constructs a State object given the return value of an OpenAI gym reset()/step(action) call.
 
@@ -446,29 +480,15 @@ def from_zoo(cls, agent, state, device='cpu', dtype=np.float32):
         Returns:
             A State object.
         """
-        if not isinstance(state, tuple):
-            return MultiagentState({
-                'agent': agent,
-                'observation': torch.from_numpy(
-                    np.array(
-                        state,
-                        dtype=dtype
-                    ),
-                ).to(device)
-            }, device=device)
-
-        observation, reward, done, info = state
-        observation = torch.from_numpy(
-            np.array(
-                observation,
-                dtype=dtype
-            ),
-        ).to(device)
+        observation, reward, terminated, truncated, info = state
         x = {
-            'agent': agent,
-            'observation': observation,
-            'reward': float(reward),
-            'done': done,
+            "agent": agent,
+            "observation": torch.from_numpy(
+                np.array(observation, dtype=dtype),
+            ).to(device),
+            "reward": float(reward),
+            "done": terminated or truncated,
+            "mask": 1.0 - terminated,
         }
         info = info if info else {}
         for key in info:
diff --git a/all/core/state_test.py b/all/core/state_test.py
index 0de2814b..597563da 100644
--- a/all/core/state_test.py
+++ b/all/core/state_test.py
@@ -77,6 +77,26 @@ def test_from_truncated_gym_step(self):
         self.assertEqual(state['coolInfo'], 3.)
         self.assertEqual(state.shape, ())
 
+    def test_legacy_gym_step(self):
+        observation = np.array([1, 2, 3])
+        state = State.from_gym((observation, 2., True, {'coolInfo': 3.}))
+        tt.assert_equal(state.observation, torch.from_numpy(observation))
+        self.assertEqual(state.mask, 0.)
+        self.assertEqual(state.done, True)
+        self.assertEqual(state.reward, 2.)
+        self.assertEqual(state['coolInfo'], 3.)
+        self.assertEqual(state.shape, ())
+
+    def test_from_truncated_gym_step(self):
+        observation = np.array([1, 2, 3])
+        state = State.from_gym((observation, 2., False, True, {'coolInfo': 3.}))
+        tt.assert_equal(state.observation, torch.from_numpy(observation))
+        self.assertEqual(state.mask, 1.)
+        self.assertEqual(state.done, True)
+        self.assertEqual(state.reward, 2.)
+        self.assertEqual(state['coolInfo'], 3.)
+        self.assertEqual(state.shape, ())
+
     def test_as_input(self):
         observation = torch.randn(3, 4)
         state = State(observation)
diff --git a/all/environments/atari.py b/all/environments/atari.py
index 8db1606a..204fbbdf 100644
--- a/all/environments/atari.py
+++ b/all/environments/atari.py
@@ -37,8 +37,7 @@ def __init__(self, name, device='cpu'):
         self._device = device
 
     def reset(self):
-        state = self._env.reset(), 0., False, None
-        self._state = State.from_gym(state, dtype=self._env.observation_space.dtype, device=self._device)
+        self._state = State.from_gym(self._env.reset(), dtype=self._env.observation_space.dtype, device=self._device)
         return self._state
 
     def step(self, action):
diff --git a/all/environments/atari_wrappers.py b/all/environments/atari_wrappers.py
index 71e05a2a..df6013ef 100644
--- a/all/environments/atari_wrappers.py
+++ b/all/environments/atari_wrappers.py
@@ -34,8 +34,8 @@ def reset(self, **kwargs):
         assert noops > 0
         obs = None
         for _ in range(noops):
-            obs, _, done, _ = self.env.step(self.noop_action)
-            if done:
+            obs, _, terminated, truncated, _ = self.env.step(self.noop_action)
+            if terminated or truncated:
                 obs = self.env.reset(**kwargs)
         return obs
 
@@ -58,26 +58,25 @@ def __init__(self, env):
 
     def reset(self, **kwargs):
         self.env.reset(**kwargs)
-        obs, _ = self.fire()
+        obs, info = self.fire()
         self.lives = self.env.unwrapped.ale.lives()
-        return obs
+        return obs, info
 
     def step(self, action):
-        obs, reward, done, info = self.env.step(action)
+        obs, reward, terminated, truncated, info = self.env.step(action)
         if self.lost_life():
-            obs, done = self.fire()
+            obs, info = self.fire()
             self.lives = self.env.unwrapped.ale.lives()
-        return obs, reward, done, info
+        return obs, reward, terminated, truncated, info
 
     def fire(self):
-        obs, _, done, _ = self.env.step(1)
-        if done:
+        obs, _, terminated, truncated, info = self.env.step(1)
+        if terminated or truncated:
             self.env.reset()
-        obs, _, done, _ = self.env.step(2)
-        if done:
-            obs = self.env.reset()
-            done = False
-        return obs, done
+        obs, _, terminated, truncated, info = self.env.step(2)
+        if terminated or truncated:
+            obs, info = self.env.reset()
+        return obs, info
 
     def lost_life(self):
         lives = self.env.unwrapped.ale.lives()
@@ -95,21 +94,20 @@ def __init__(self, env, skip=4):
     def step(self, action):
         '''Repeat action, sum reward, and max over last observations.'''
         total_reward = 0.0
-        done = None
        for i in range(self._skip):
-            obs, reward, done, info = self.env.step(action)
+            obs, reward, terminated, truncated, info = self.env.step(action)
            if i == self._skip - 2:
                self._obs_buffer[0] = obs
            if i == self._skip - 1:
                self._obs_buffer[1] = obs
            total_reward += reward
-            if done:
+            if terminated or truncated:
                break
        # Note that the observation on the done=True frame
        # doesn't matter
        max_frame = self._obs_buffer.max(axis=0)
 
-        return max_frame, total_reward, done, info
+        return max_frame, total_reward, terminated, truncated, info
 
     def reset(self, **kwargs):
         return self.env.reset(**kwargs)
@@ -183,9 +181,9 @@ def reset(self):
         return self.env.reset()
 
     def step(self, action):
-        obs, reward, done, _ = self.env.step(action)
+        obs, reward, terminated, truncated, _ = self.env.step(action)
         lives = self.env.unwrapped.ale.lives()
         life_lost = (lives < self.lives and lives > 0)
         self.lives = lives
         info = {'life_lost': life_lost}
-        return obs, reward, done, info
+        return obs, reward, terminated, truncated, info
diff --git a/all/environments/duplicate_env.py b/all/environments/duplicate_env.py
index 0f7cec42..42fbd3b4 100644
--- a/all/environments/duplicate_env.py
+++ b/all/environments/duplicate_env.py
@@ -32,8 +32,11 @@ def __init__(self, envs, device=torch.device('cpu')):
     def name(self):
         return self._name
 
-    def reset(self):
-        self._state = State.array([sub_env.reset() for sub_env in self._envs])
+    def reset(self, seed=None, **kwargs):
+        if seed is not None:
+            self._state = State.array([sub_env.reset(seed=(seed + i), **kwargs) for i, sub_env in enumerate(self._envs)])
+        else:
+            self._state = State.array([sub_env.reset(**kwargs) for sub_env in self._envs])
         return self._state
 
     def step(self, actions):
@@ -48,10 +51,6 @@ def step(self, actions):
     def close(self):
         return self._env.close()
 
-    def seed(self, seed):
-        for i, env in enumerate(self._envs):
-            env.seed(seed + i)
-
     @property
     def state_space(self):
         return self._envs[0].observation_space
diff --git a/all/environments/duplicate_env_test.py b/all/environments/duplicate_env_test.py
index 23c2c66e..1c0c750e 100644
--- a/all/environments/duplicate_env_test.py
+++ b/all/environments/duplicate_env_test.py
@@ -42,7 +42,6 @@ def test_step(self):
     def test_step_until_done(self):
         num_envs = 3
         env = DuplicateEnvironment(make_vec_env(num_envs))
-        env.seed(5)
         env.reset()
         for _ in range(100):
             state = env.step(torch.ones(num_envs, dtype=torch.int32))
diff --git a/all/environments/gym.py b/all/environments/gym.py
index 73b820d1..dcc82b11 100644
--- a/all/environments/gym.py
+++ b/all/environments/gym.py
@@ -21,10 +21,15 @@ class GymEnvironment(Environment):
         env: Either a string or an OpenAI gym environment
         name (str, optional): the name of the environment
         device (str, optional): the device on which tensors will be stored
+        legacy_gym (bool, optional): If True, calls gym.make() instead of gymnasium.make()
     '''
 
-    def __init__(self, id, device=torch.device('cpu'), name=None):
-        self._env = gymnasium.make(id)
+    def __init__(self, id, device=torch.device('cpu'), name=None, legacy_gym=False):
+        if legacy_gym:
+            import gym
+            self._env = gym.make(id)
+        else:
+            self._env = gymnasium.make(id)
         self._id = id
         self._name = name if name else id
         self._state = None
@@ -38,9 +43,8 @@ def __init__(self, id, device=torch.device('cpu'), name=None):
     def name(self):
         return self._name
 
-    def reset(self):
-        state = self._env.reset(), 0., False, None
-        self._state = State.from_gym(state, dtype=self._env.observation_space.dtype, device=self._device)
+    def reset(self, **kwargs):
+        self._state = State.from_gym(self._env.reset(**kwargs), dtype=self._env.observation_space.dtype, device=self._device)
         return self._state
 
     def step(self, action):
diff --git a/all/environments/multiagent_atari.py b/all/environments/multiagent_atari.py
index 626c92d9..e9a10de1 100644
--- a/all/environments/multiagent_atari.py
+++ b/all/environments/multiagent_atari.py
@@ -25,10 +25,10 @@ def __init__(self, env_name, device='cuda', **pettingzoo_params):
 
     def _load_env(self, env_name, pettingzoo_params):
         from pettingzoo import atari
-        from supersuit import resize_v0, frame_skip_v0, reshape_v0, max_observation_v0
+        from supersuit import resize_v1, frame_skip_v0, reshape_v0, max_observation_v0
         env = importlib.import_module('pettingzoo.atari.{}'.format(env_name)).env(obs_type='grayscale_image', **pettingzoo_params)
         env = max_observation_v0(env, 2)
         env = frame_skip_v0(env, 4)
-        env = resize_v0(env, 84, 84)
+        env = resize_v1(env, 84, 84)
         env = reshape_v0(env, (1, 84, 84))
         return env
diff --git a/all/environments/multiagent_atari_test.py b/all/environments/multiagent_atari_test.py
index 251e98f4..c5c5e7a8 100644
--- a/all/environments/multiagent_atari_test.py
+++ b/all/environments/multiagent_atari_test.py
@@ -5,12 +5,12 @@
 
 class MultiagentAtariEnvTest(unittest.TestCase):
     def test_init(self):
-        MultiagentAtariEnv('pong_v2', device='cpu')
-        MultiagentAtariEnv('mario_bros_v2', device='cpu')
-        MultiagentAtariEnv('entombed_cooperative_v2', device='cpu')
+        MultiagentAtariEnv('pong_v3', device='cpu')
+        MultiagentAtariEnv('mario_bros_v3', device='cpu')
+        MultiagentAtariEnv('entombed_cooperative_v3', device='cpu')
 
     def test_reset(self):
-        env = MultiagentAtariEnv('pong_v2', device='cpu')
+        env = MultiagentAtariEnv('pong_v3', device='cpu')
         state = env.reset()
         self.assertEqual(state.observation.shape, (1, 84, 84))
         self.assertEqual(state.reward, 0)
@@ -19,7 +19,7 @@ def test_reset(self):
         self.assertEqual(state['agent'], 'first_0')
 
     def test_step(self):
-        env = MultiagentAtariEnv('pong_v2', device='cpu')
+        env = MultiagentAtariEnv('pong_v3', device='cpu')
         env.reset()
         state = env.step(0)
         self.assertEqual(state.observation.shape, (1, 84, 84))
@@ -29,7 +29,7 @@ def test_step(self):
         self.assertEqual(state['agent'], 'second_0')
 
     def test_step_tensor(self):
-        env = MultiagentAtariEnv('pong_v2', device='cpu')
+        env = MultiagentAtariEnv('pong_v3', device='cpu')
         env.reset()
         state = env.step(torch.tensor([0]))
         self.assertEqual(state.observation.shape, (1, 84, 84))
@@ -39,37 +39,37 @@ def test_step_tensor(self):
         self.assertEqual(state['agent'], 'second_0')
 
     def test_name(self):
-        env = MultiagentAtariEnv('pong_v2', device='cpu')
-        self.assertEqual(env.name, 'pong_v2')
+        env = MultiagentAtariEnv('pong_v3', device='cpu')
+        self.assertEqual(env.name, 'pong_v3')
 
     def test_agent_iter(self):
-        env = MultiagentAtariEnv('pong_v2', device='cpu')
+        env = MultiagentAtariEnv('pong_v3', device='cpu')
         env.reset()
         it = iter(env.agent_iter())
         self.assertEqual(next(it), 'first_0')
 
     def test_state_spaces(self):
-        env = MultiagentAtariEnv('pong_v2', device='cpu')
+        env = MultiagentAtariEnv('pong_v3', device='cpu')
         self.assertEqual(env.state_space('first_0').shape, (1, 84, 84))
         self.assertEqual(env.state_space('second_0').shape, (1, 84, 84))
 
     def test_action_spaces(self):
-        env = MultiagentAtariEnv('pong_v2', device='cpu')
-        self.assertEqual(env.action_space('first_0').n, 18)
-        self.assertEqual(env.action_space('second_0').n, 18)
+        env = MultiagentAtariEnv('pong_v3', device='cpu')
+        self.assertEqual(env.action_space('first_0').n, 6)
+        self.assertEqual(env.action_space('second_0').n, 6)
 
     def test_list_agents(self):
-        env = MultiagentAtariEnv('pong_v2', device='cpu')
+        env = MultiagentAtariEnv('pong_v3', device='cpu')
         self.assertEqual(env.agents, ['first_0', 'second_0'])
 
     def test_is_done(self):
-        env = MultiagentAtariEnv('pong_v2', device='cpu')
+        env = MultiagentAtariEnv('pong_v3', device='cpu')
         env.reset()
         self.assertFalse(env.is_done('first_0'))
         self.assertFalse(env.is_done('second_0'))
 
     def test_last(self):
-        env = MultiagentAtariEnv('pong_v2', device='cpu')
+        env = MultiagentAtariEnv('pong_v3', device='cpu')
         env.reset()
         state = env.last()
         self.assertEqual(state.observation.shape, (1, 84, 84))
diff --git a/all/environments/multiagent_pettingzoo.py b/all/environments/multiagent_pettingzoo.py
index 7e1716a3..3376e519 100644
--- a/all/environments/multiagent_pettingzoo.py
+++ b/all/environments/multiagent_pettingzoo.py
@@ -38,8 +38,8 @@ def __init__(self, zoo_env, name, device='cuda'):
         An initial MultiagentState object.
     '''
 
-    def reset(self):
-        self._env.reset()
+    def reset(self, **kwargs):
+        self._env.reset(**kwargs)
         return self.last()
 
     '''
@@ -72,15 +72,15 @@ def agent_iter(self):
         return self._env.agent_iter()
 
     def is_done(self, agent):
-        return self._env.dones[agent]
+        return self._env.terminations[agent]
 
     def duplicate(self, n):
         return [MultiagentPettingZooEnv(cloudpickle.loads(cloudpickle.dumps(self._env)), self._name, device=self.device) for _ in range(n)]
 
     def last(self):
-        observation, reward, done, info = self._env.last()
+        observation, reward, terminated, truncated, info = self._env.last()
         selected_obs_space = self._env.observation_space(self._env.agent_selection)
-        return MultiagentState.from_zoo(self._env.agent_selection, (observation, reward, done, info), device=self._device, dtype=selected_obs_space.dtype)
+        return MultiagentState.from_zoo(self._env.agent_selection, (observation, reward, terminated, truncated, info), device=self._device, dtype=selected_obs_space.dtype)
 
     @property
     def name(self):
diff --git a/all/environments/multiagent_pettingzoo_test.py b/all/environments/multiagent_pettingzoo_test.py
index d2e4c7df..482e50e5 100644
--- a/all/environments/multiagent_pettingzoo_test.py
+++ b/all/environments/multiagent_pettingzoo_test.py
@@ -1,7 +1,7 @@
 import unittest
 import torch
 from all.environments import MultiagentPettingZooEnv
-from pettingzoo.mpe import simple_world_comm_v2
+from pettingzoo.mpe import simple_world_comm_v3
 
 
 class MultiagentPettingZooEnvTest(unittest.TestCase):
@@ -39,7 +39,7 @@ def test_step_tensor(self):
 
     def test_name(self):
         env = self._make_env()
-        self.assertEqual(env.name, 'simple_world_comm_v2')
+        self.assertEqual(env.name, 'simple_world_comm_v3')
 
     def test_agent_iter(self):
         env = self._make_env()
@@ -61,7 +61,7 @@ def test_list_agents(self):
         env = self._make_env()
         self.assertEqual(env.agents, ['leadadversary_0', 'adversary_0', 'adversary_1', 'adversary_2', 'agent_0', 'agent_1'])
 
-    def test_is_done(self):
+    def test_terminated(self):
         env = self._make_env()
         env.reset()
         self.assertFalse(env.is_done('leadadversary_0'))
@@ -78,7 +78,7 @@ def test_last(self):
         self.assertEqual(state['agent'], 'leadadversary_0')
 
     def test_variable_spaces(self):
-        env = MultiagentPettingZooEnv(simple_world_comm_v2.env(), name="simple_world_comm_v2", device='cpu')
+        env = MultiagentPettingZooEnv(simple_world_comm_v3.env(), name="simple_world_comm_v2", device='cpu')
         state = env.reset()
         # tests that action spaces work
         for agent in env.agents:
@@ -87,7 +87,7 @@ def test_variable_spaces(self):
             env.step(env.action_space(env.agent_selection).sample())
 
     def _make_env(self):
-        return MultiagentPettingZooEnv(simple_world_comm_v2.env(), name="simple_world_comm_v2", device='cpu')
+        return MultiagentPettingZooEnv(simple_world_comm_v3.env(), name="simple_world_comm_v3", device='cpu')
 
 
 if __name__ == "__main__":
diff --git a/all/environments/pybullet.py b/all/environments/pybullet.py
index a986e5b4..70f379fb 100644
--- a/all/environments/pybullet.py
+++ b/all/environments/pybullet.py
@@ -14,4 +14,4 @@ def __init__(self, name, **kwargs):
         import pybullet_envs
         if name in self.short_names:
             name = self.short_names[name]
-        super().__init__(name, **kwargs)
+        super().__init__(name, legacy_gym=True, **kwargs)
diff --git a/all/environments/vector_env.py b/all/environments/vector_env.py
index b57150e9..d74a46af 100644
--- a/all/environments/vector_env.py
+++ b/all/environments/vector_env.py
@@ -35,16 +35,16 @@ def __init__(self, vec_env, name, device=torch.device('cpu')):
     def name(self):
         return self._name
 
-    def reset(self):
-        state_tuple = self._env.reset(), np.zeros(self._env.num_envs), np.zeros(self._env.num_envs), None
-        self._state = self._to_state(*state_tuple)
+    def reset(self, **kwargs):
+        obs, info = self._env.reset(**kwargs)
+        self._state = self._to_state(obs, np.zeros(self._env.num_envs), np.zeros(self._env.num_envs), np.zeros(self._env.num_envs), info)
         return self._state
 
-    def _to_state(self, obs, rew, done, info):
+    def _to_state(self, obs, rew, terminated, truncated, info):
         obs = obs.astype(self.observation_space.dtype)
         rew = rew.astype("float32")
-        done = done.astype("bool")
-        mask = (1 - done).astype("float32")
+        done = (terminated + truncated).astype("bool")
+        mask = (1 - terminated).astype("float32")
         return StateArray({
             "observation": torch.tensor(obs, device=self._device),
             "reward": torch.tensor(rew, device=self._device),
@@ -60,9 +60,6 @@ def step(self, action):
     def close(self):
         return self._env.close()
 
-    def seed(self, seed):
-        self._env.seed(seed)
-
     @property
     def state_space(self):
         return getattr(self._env, "single_observation_space", getattr(self._env, "observation_space"))
diff --git a/all/environments/vector_env_test.py b/all/environments/vector_env_test.py
index 3eaaa864..a4cfba77 100644
--- a/all/environments/vector_env_test.py
+++ b/all/environments/vector_env_test.py
@@ -42,8 +42,7 @@ def test_step(self):
     def test_step_until_done(self):
         num_envs = 3
         env = GymVectorEnvironment(make_vec_env(num_envs), "CartPole")
-        env.seed(5)
-        env.reset()
+        env.reset(seed=5)
         for _ in range(100):
             state = env.step(torch.ones(num_envs, dtype=torch.int32))
             if state.done[0]:
@@ -60,10 +59,8 @@ def test_same_as_duplicate(self):
         torch.manual_seed(42)
         env1 = DuplicateEnvironment([GymEnvironment('CartPole-v0') for i in range(n_envs)])
         env2 = GymVectorEnvironment(make_vec_env(n_envs), "CartPole-v0")
-        env1.seed(42)
-        env2.seed(42)
-        state1 = env1.reset()
-        state2 = env2.reset()
+        state1 = env1.reset(seed=42)
+        state2 = env2.reset(seed=42)
         self.assertEqual(env1.name, env2.name)
         self.assertEqual(env1.action_space.n, env2.action_space.n)
         self.assertEqual(env1.observation_space.shape, env2.observation_space.shape)
diff --git a/all/experiments/multiagent_env_experiment_test.py b/all/experiments/multiagent_env_experiment_test.py
index 5e7a5a4b..ceb8a47a 100644
--- a/all/experiments/multiagent_env_experiment_test.py
+++ b/all/experiments/multiagent_env_experiment_test.py
@@ -18,24 +18,25 @@ class TestMultiagentEnvExperiment(unittest.TestCase):
     def setUp(self):
         np.random.seed(0)
         torch.manual_seed(0)
-        self.env = MultiagentAtariEnv('space_invaders_v1', device='cpu')
-        self.env.seed(0)
+        self.env = MultiagentAtariEnv('space_invaders_v2', device='cpu')
+        self.env.reset(seed=0)
         self.experiment = None
 
     def test_adds_default_name(self):
         experiment = MockExperiment(self.make_preset(), self.env, quiet=True, save_freq=float('inf'))
-        self.assertEqual(experiment._logger.label, "independent_space_invaders_v1")
+        self.assertEqual(experiment._logger.label, "independent_space_invaders_v2")
 
     def test_adds_custom_name(self):
         experiment = MockExperiment(self.make_preset(), self.env, name='custom', quiet=True, save_freq=float('inf'))
-        self.assertEqual(experiment._logger.label, "custom_space_invaders_v1")
+        self.assertEqual(experiment._logger.label, "custom_space_invaders_v2")
 
     def test_writes_training_returns(self):
         experiment = MockExperiment(self.make_preset(), self.env, quiet=True, save_freq=float('inf'))
         experiment.train(episodes=3)
+        self.maxDiff = None
         self.assertEqual(experiment._logger.data, {
-            'eval/first_0/returns/frame': {'values': [465.0, 235.0, 735.0, 415.0], 'steps': [766, 1524, 2440, 3038]},
-            'eval/second_0/returns/frame': {'values': [235.0, 465.0, 170.0, 295.0], 'steps': [766, 1524, 2440, 3038]}
+            'eval/first_0/returns/frame': {'values': [705.0, 490.0, 230.0, 435.0], 'steps': [808, 1580, 2120, 3300]},
+            'eval/second_0/returns/frame': {'values': [115.0, 525.0, 415.0, 665.0], 'steps': [808, 1580, 2120, 3300]}
         })
 
     def test_writes_test_returns(self):
diff --git a/all/experiments/parallel_env_experiment_test.py b/all/experiments/parallel_env_experiment_test.py
index cf7c343d..28a9558b 100644
--- a/all/experiments/parallel_env_experiment_test.py
+++ b/all/experiments/parallel_env_experiment_test.py
@@ -9,7 +9,7 @@ class MockExperiment(ParallelEnvExperiment):
     def _make_logger(self, logdir, agent_name, env_name, verbose, logger):
-        self._logger = MockLogger(self, agent_name + '_' + env_name, verbose)
+        self._logger = MockLogger(self, agent_name + "_" + env_name, verbose)
         return self._logger
 
 
@@ -17,28 +17,28 @@ class TestParallelEnvExperiment(unittest.TestCase):
     def setUp(self):
         np.random.seed(0)
         torch.manual_seed(0)
-        self.env = GymEnvironment('CartPole-v0')
-        self.env.seed(0)
+        self.env = GymEnvironment("CartPole-v0")
+        self.env.reset(seed=0)
         self.experiment = MockExperiment(self.make_agent(), self.env, quiet=True)
-        self.experiment._env.seed(0)
+        self.experiment._env.reset(seed=0)
 
     def test_adds_default_label(self):
         self.assertEqual(self.experiment._logger.label, "a2c_CartPole-v0")
 
     def test_adds_custom_label(self):
-        env = GymEnvironment('CartPole-v0')
-        experiment = MockExperiment(self.make_agent(), env, name='a2c', quiet=True)
+        env = GymEnvironment("CartPole-v0")
+        experiment = MockExperiment(self.make_agent(), env, name="a2c", quiet=True)
         self.assertEqual(experiment._logger.label, "a2c_CartPole-v0")
 
     def test_writes_training_returns_eps(self):
-        self.experiment.train(episodes=3)
+        self.experiment.train(episodes=4)
         np.testing.assert_equal(
             self.experiment._logger.data["eval/returns/episode"]["steps"],
-            np.array([1, 2, 3]),
+            np.array([1, 2, 3, 3]),
         )
         np.testing.assert_equal(
             self.experiment._logger.data["eval/returns/episode"]["values"],
-            np.array([10., 12., 19.]),
+            np.array([12.0, 13.0, 16.0, 16.0]),
         )
 
     def test_writes_test_returns(self):
@@ -55,13 +55,17 @@ def test_writes_test_returns(self):
         )
 
     def test_writes_loss(self):
-        experiment = MockExperiment(self.make_agent(), self.env, quiet=True, verbose=True)
+        experiment = MockExperiment(
+            self.make_agent(), self.env, quiet=True, verbose=True
+        )
         self.assertTrue(experiment._logger.verbose)
-        experiment = MockExperiment(self.make_agent(), self.env, quiet=True, verbose=False)
+        experiment = MockExperiment(
+            self.make_agent(), self.env, quiet=True, verbose=False
+        )
         self.assertFalse(experiment._logger.verbose)
 
     def make_agent(self):
-        return a2c.device('cpu').env(self.env).build()
+        return a2c.device("cpu").env(self.env).build()
 
 
 if __name__ == "__main__":
diff --git a/all/experiments/single_env_experiment_test.py b/all/experiments/single_env_experiment_test.py
index f5f03789..6c11b04b 100644
--- a/all/experiments/single_env_experiment_test.py
+++ b/all/experiments/single_env_experiment_test.py
@@ -58,7 +58,7 @@ def setUp(self):
         np.random.seed(0)
         torch.manual_seed(0)
         self.env = GymEnvironment('CartPole-v0')
-        self.env.seed(0)
+        self.env.reset(seed=0)
         self.experiment = None
 
     def test_adds_default_name(self):
@@ -74,7 +74,7 @@ def test_writes_training_returns_eps(self):
         experiment.train(episodes=3)
         np.testing.assert_equal(
             experiment._logger.data["eval/returns/episode"]["values"],
-            np.array([18., 23., 27.]),
+            np.array([22., 17., 28.]),
         )
         np.testing.assert_equal(
             experiment._logger.data["eval/returns/episode"]["steps"],
@@ -85,8 +85,8 @@ def test_writes_test_returns(self):
         experiment = MockExperiment(self.make_preset(), self.env, quiet=True)
         experiment.train(episodes=5)
         returns = experiment.test(episodes=4)
-        expected_mean = 8.75
-        expected_std = 0.433013
+        expected_mean = 8.5
+        expected_std = 0.5
         np.testing.assert_equal(np.mean(returns), expected_mean)
         np.testing.assert_equal(
             experiment._logger.data["summary/returns-test/mean"]["values"],
@@ -99,7 +99,7 @@ def test_writes_test_returns(self):
         )
         np.testing.assert_equal(
             experiment._logger.data["summary/returns-test/mean"]["steps"],
-            np.array([94]),
+            np.array([93]),
         )
 
     def test_writes_loss(self):
diff --git a/all/presets/atari/models/test_.py b/all/presets/atari/models/test_.py
deleted file mode 100644
index 68b1360d..00000000
--- a/all/presets/atari/models/test_.py
+++ /dev/null
@@ -1,146 +0,0 @@
-import unittest
-import torch
-import torch_testing as tt
-from all.environments import AtariEnvironment
-from all.presets.atari.models import nature_rainbow
-
-
-class TestAtariModels(unittest.TestCase):
-    def setUp(self):
-        torch.manual_seed(0)
-        torch.backends.cudnn.deterministic = True
-        torch.backends.cudnn.benchmark = False
-
-    def test_rainbow_model_cpu(self):
-        env = AtariEnvironment('Breakout')
-        model = nature_rainbow(env)
-        env.reset()
-        x = torch.cat([env.state.raw] * 4, dim=1).float()
-        out = model(x)
-        tt.assert_almost_equal(
-            out,
-            torch.tensor([[
-                0.0676, -0.0235, 0.0690, -0.0713, -0.0287, 0.0053, -0.0463, 0.0495,
-                -0.0222, -0.0504, 0.0064, -0.0204, 0.0168, 0.0127, -0.0113, -0.0586,
-                -0.0544, 0.0114, -0.0077, 0.0666, -0.0663, -0.0420, -0.0698, -0.0314,
-                0.0272, 0.0361, -0.0537, 0.0301, 0.0036, -0.0472, -0.0499, 0.0114,
-                0.0182, 0.0008, -0.0132, -0.0803, -0.0087, -0.0017, 0.0598, -0.0627,
-                0.0859, 0.0117, 0.0105, 0.0309, -0.0370, -0.0111, -0.0262, 0.0338,
-                0.0141, -0.0385, 0.0547, 0.0648, -0.0370, 0.0107, -0.0629, -0.0163,
-                0.0282, -0.0670, 0.0161, -0.0244, -0.0030, 0.0038, -0.0208, 0.0005,
-                0.0125, 0.0608, -0.0089, 0.0026, 0.0562, -0.0678, 0.0841, -0.0265,
-                -0.0461, -0.0124, 0.0276, 0.0364, 0.0195, -0.0309, -0.0337, -0.0603,
-                -0.0252, -0.0356, 0.0221, 0.0184, -0.0154, -0.0136, -0.0277, 0.0283,
-                0.0495, 0.0185, -0.0357, 0.0305, -0.0052, -0.0432, -0.0135, -0.0554,
-                -0.0094, 0.0272, 0.1030, 0.0049, 0.0012, -0.0140, 0.0146, -0.0979,
-                0.0487, 0.0122, -0.0204, 0.0496, -0.0055, -0.0015, -0.0170, 0.0053,
-                0.0104, -0.0742, 0.0742, -0.0381, 0.0104, -0.0065, -0.0564, 0.0453,
-                -0.0057, -0.0029, -0.0722, 0.0094, -0.0561, 0.0284, 0.0402, 0.0233,
-                -0.0716, -0.0424, 0.0165, -0.0505, 0.0006, 0.0219, -0.0601, 0.0656,
-                -0.0175, -0.0524, 0.0355, 0.0007, -0.0042, -0.0443, 0.0871, -0.0403,
-                -0.0031, 0.0171, -0.0359, -0.0520, -0.0344, 0.0239, 0.0099, 0.0004,
-                0.0235, 0.0238, -0.0153, 0.0501, -0.0052, 0.0162, 0.0313, -0.0121,
-                0.0009, -0.0366, -0.0628, 0.0386, -0.0671, 0.0480, -0.0595, 0.0568,
-                -0.0604, -0.0540, 0.0403, -0.0187, 0.0649, 0.0029, -0.0003, 0.0020,
-                -0.0056, 0.0471, -0.0145, -0.0126, -0.0395, -0.0455, -0.0437, 0.0056,
-                0.0331, 0.0004, 0.0127, -0.0022, -0.0502, 0.0362, 0.0624, -0.0012,
-                -0.0515, 0.0303, -0.0357, -0.0420, 0.0321, -0.0162, 0.0007, -0.0272,
-                0.0227, 0.0187, -0.0459, 0.0496
-            ]]),
-            decimal=3
-        )
-
-    def test_rainbow_model_cuda(self):
-        env = AtariEnvironment('Breakout')
-        model = nature_rainbow(env).cuda()
-        env.reset()
-        x = torch.cat([env.state.raw] * 4, dim=1).float().cuda()
-        out = model(x)
-        tt.assert_almost_equal(
-            out.cpu(),
-            torch.tensor([[
-                -1.4765e-02, -4.0353e-02, -2.1705e-02, -2.2314e-02, 3.6881e-02,
-                -1.4175e-02, 1.2442e-02, -6.8713e-03, 2.4970e-02, 2.5681e-02,
-                -4.5859e-02, -2.3327e-02, 3.6205e-02, 7.1024e-03, -2.7564e-02,
-                2.1592e-02, -3.2728e-02, 1.3602e-02, -1.1690e-02, -4.3082e-02,
-                -1.2996e-02, 1.7184e-02, 1.3446e-02, -3.3587e-03, -4.6350e-02,
-                -1.7646e-02, 2.1954e-02, 8.5546e-03, -2.1359e-02, -2.4206e-02,
-                -2.3151e-02, -3.6330e-02, 4.4699e-02, 3.9887e-03, 1.5609e-02,
-                -4.3950e-02, 1.0955e-02, -2.4277e-02, 1.4915e-02, 3.2508e-03,
-                6.1454e-02, 3.5242e-02, -1.5274e-02, -2.6729e-02, -2.4072e-02,
-                1.5696e-02, 2.6622e-02, -3.5404e-02, 5.1701e-02, -5.3047e-02,
-                -1.8412e-02, 8.6640e-03, -3.1722e-02, 4.0329e-02, 1.2896e-02,
-                -1.4139e-02, -4.9200e-02, -4.6193e-02, -2.9064e-03, -2.2078e-02,
-                -4.0084e-02, -8.3519e-03, -2.7589e-02, -4.9979e-03, -1.6055e-02,
-                -4.5311e-02, -2.6951e-02, 2.8032e-02, -4.0069e-03, 3.2405e-02,
-                -5.3164e-03, -3.0139e-03, 6.6179e-04, -4.9243e-02, 3.2515e-02,
-                9.8307e-03, -3.4257e-03, -3.9522e-02, 1.2594e-02, -2.7210e-02,
-                2.3451e-02, 4.2257e-02, 2.2239e-02, 1.4304e-04, 4.2905e-04,
-                1.5193e-02, 3.1897e-03, -1.0828e-02, -4.8345e-02, 6.8747e-02,
-                -7.1725e-03, -9.7815e-03, -1.6331e-02, 1.0434e-02, -8.8083e-04,
-                3.8219e-02, 6.8332e-03, -2.0189e-02, 2.8141e-02, 1.4913e-02,
-                -2.4925e-02, -2.8922e-02, -7.1546e-03, 1.9791e-02, 1.1160e-02,
-                1.0306e-02, -1.3631e-02, 2.7318e-03, 1.4050e-03, -8.2064e-03,
-                3.5836e-02, -1.5877e-02, -1.1198e-02, 1.9514e-02, 3.0832e-03,
-                -6.2730e-02, 6.1493e-03, -1.2340e-02, 3.9110e-02, -2.6895e-02,
-                -5.1718e-03, 7.5017e-03, 1.2673e-03, 4.7525e-02, 1.7373e-03,
-                -5.1745e-03, -2.8621e-02, 3.4984e-02, -3.2622e-02, 1.0748e-02,
-                1.2499e-02, -1.8788e-02, -8.6717e-03, 4.3620e-02, 2.8460e-02,
-                -6.8146e-03, -3.5824e-02, 9.2931e-03, 3.7893e-03, 2.4187e-02,
-                1.3393e-02, -5.9393e-03, -9.9837e-03, -8.1019e-03, -2.1840e-02,
-                -3.8945e-02, 1.6736e-02, -4.7475e-02, 4.9770e-02, 3.4695e-02,
-                1.8961e-02, 2.7416e-02, -1.3578e-02, -9.8595e-03, 2.2834e-03,
-                2.4829e-02, -4.3998e-02, 3.2398e-02, -1.4200e-02, 2.4907e-02,
-                -2.2542e-02, -9.2765e-03, 2.0658e-03, -4.1246e-03, -1.8095e-02,
-                -1.2732e-02, -3.2090e-03, 1.3127e-02, -2.0888e-02, 1.4931e-02,
-                -4.0576e-02, 4.2877e-02, 7.9411e-05, -4.4377e-02, 3.2357e-03,
-                1.6201e-02, 4.0387e-02, -1.9023e-02, 5.8033e-02, -3.3424e-02,
-                2.9598e-03, -1.8526e-02, -2.2967e-02, 4.3449e-02, -1.2564e-02,
-                -9.3756e-03, -2.1745e-02, -2.7089e-02, -3.6791e-02, -5.2018e-02,
-                2.4588e-02, 1.0037e-03, 3.9753e-02, 4.3534e-02, 2.6446e-02,
-                -1.1808e-02, 2.1426e-02, 7.5522e-03, 2.2847e-03, -2.7211e-02,
-                4.1364e-02, -1.1281e-02, 1.6523e-03, -1.9913e-03
-            ]]),
-            decimal=3
-        )
-        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
-        loss = out.sum()
-        loss.backward()
-        optimizer.step()
-        out = model(x)
-        tt.assert_almost_equal(
-            out.cpu(),
-            torch.tensor([[
-                -0.0247, -0.0172, -0.0633, -0.0154, -0.0156, -0.1156, -0.0793, -0.0184,
-                -0.0408, 0.0005, -0.0920, -0.0481, -0.0597, -0.0243, 0.0006, -0.1045,
-                -0.0476, -0.0030, -0.0230, -0.0869, -0.0149, -0.0412, -0.0753, -0.0640,
-                -0.1106, -0.0632, -0.0645, -0.0474, -0.0124, -0.0698, -0.0275, -0.0415,
-                -0.0916, -0.0957, -0.0851, -0.1296, -0.1049, -0.0196, -0.0823, -0.0380,
-                -0.1085, -0.0526, -0.0083, -0.1274, -0.0426, -0.0183, -0.0585, -0.0366,
-                -0.1111, -0.0074, -0.1238, -0.0324, -0.0166, -0.0719, -0.0285, -0.0427,
-                -0.1158, -0.0569, 0.0075, -0.0419, -0.0288, -0.1189, -0.0220, -0.0370,
-                0.0040, 0.0228, -0.0958, -0.0258, -0.0276, -0.0405, -0.0958, -0.0201,
-                -0.0639, -0.0543, -0.0705, -0.0940, -0.0700, -0.0921, -0.0426, 0.0026,
-                -0.0556, -0.0439, -0.0386, -0.0957, -0.0915, -0.0679, -0.1272, -0.0754,
-                -0.0076, -0.1046, -0.0350, -0.0887, -0.0350, -0.0270, -0.1188, -0.0449,
-                0.0020, -0.0406, 0.0011, -0.0842, -0.0422, -0.1280, -0.0205, 0.0002,
-                -0.0789, -0.0185, -0.0510, -0.1180, -0.0550, -0.0159, -0.0702, -0.0029,
-                -0.0891, -0.0253, -0.0485, -0.0128, 0.0010, -0.0870, -0.0230, -0.0233,
-                -0.0411, -0.0870, -0.0419, -0.0688, -0.0583, -0.0448, -0.0864, -0.0926,
-                -0.0758, -0.0540, 0.0058, -0.0843, -0.0365, -0.0608, -0.0787, -0.0938,
-                -0.0680, -0.0995, -0.0764, 0.0061, -0.0821, -0.0636, -0.0848, -0.0373,
-                -0.0285, -0.1086, -0.0464, -0.0228, -0.0464, -0.0279, -0.1053, -0.0224,
-                -0.1268, -0.0006, -0.0186, -0.0836, -0.0011, -0.0415, -0.1222, -0.0668,
-                -0.0015, -0.0535, -0.0071, -0.1202, -0.0257, -0.0503, 0.0004, 0.0099,
-                -0.1113, -0.0182, -0.0080, -0.0216, -0.0661, -0.0115, -0.0468, -0.0716,
-                -0.0404, -0.0950, -0.0681, -0.0933, -0.0699, -0.0154, -0.0853, -0.0414,
-                -0.0403, -0.0700, -0.0685, -0.0975, -0.0934, -0.1016, -0.0121, -0.1084,
-                -0.0391, -0.1006, -0.0441, -0.0024, -0.1232, -0.0159, 0.0012, -0.0480,
-                -0.0013, -0.0789, -0.0309, -0.1101
-            ]]),
-            decimal=3
-        )
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/all/presets/multiagent_atari_test.py b/all/presets/multiagent_atari_test.py
index 8a213770..4f752da5 100644
--- a/all/presets/multiagent_atari_test.py
+++ b/all/presets/multiagent_atari_test.py
@@ -9,7 +9,7 @@
 
 class TestMultiagentAtariPresets(unittest.TestCase):
     def setUp(self):
-        self.env = MultiagentAtariEnv('pong_v2', device='cpu')
+        self.env = MultiagentAtariEnv('pong_v3', device='cpu')
         self.env.reset()
 
     def tearDown(self):
@@ -17,12 +17,11 @@ def tearDown(self):
             os.remove('test_preset.pt')
 
     def test_independent(self):
-        env = MultiagentAtariEnv('pong_v2', device='cpu')
         presets = {
-            agent_id: dqn.device('cpu').env(env.subenvs[agent_id]).build()
-            for agent_id in env.agents
+            agent_id: dqn.device('cpu').env(self.env.subenvs[agent_id]).build()
+            for agent_id in self.env.agents
         }
-        self.validate_preset(IndependentMultiagentPreset('independent', 'cpu', presets), env)
+        self.validate_preset(IndependentMultiagentPreset('independent', 'cpu', presets), self.env)
 
     def validate_preset(self, preset, env):
         # normal agent
diff --git a/integration/multiagent_atari_test.py b/integration/multiagent_atari_test.py
index ab48adfd..9e0b54cc 100644
--- a/integration/multiagent_atari_test.py
+++ b/integration/multiagent_atari_test.py
@@ -20,7 +20,7 @@
 
 class TestMultiagentAtariPresets(unittest.TestCase):
     def test_independent(self):
-        env = MultiagentAtariEnv('pong_v2', max_cycles=1000, device=CPU)
+        env = MultiagentAtariEnv('pong_v3', max_cycles=1000, device=CPU)
         presets = {
             agent_id: dqn.device(CPU).env(env.subenvs[agent_id]).build()
             for agent_id in env.agents
@@ -28,7 +28,7 @@ def test_independent(self):
         validate_multiagent(IndependentMultiagentPreset('independent', CPU, presets), env)
 
     def test_independent_cuda(self):
-        env = MultiagentAtariEnv('pong_v2', max_cycles=1000, device=CUDA)
+        env = MultiagentAtariEnv('pong_v3', max_cycles=1000, device=CUDA)
         presets = {
             agent_id: dqn.device(CUDA).env(env.subenvs[agent_id]).build()
             for agent_id in env.agents
diff --git a/setup.py b/setup.py
index ecfe3c20..0b81db8f 100644
--- a/setup.py
+++ b/setup.py
@@ -14,28 +14,35 @@
     ],
     "pybullet": [
         "pybullet>=3.2.2",
+        "gym>=0.10.0,<0.26.0",
     ],
     "ma-atari": [
-        "PettingZoo[atari]~={}".format(PETTINGZOO_VERSION),
-        "supersuit~=3.3.5",
+        "PettingZoo[atari, accept-rom-license]~={}".format(PETTINGZOO_VERSION),
+        "supersuit~=3.9.1",
     ],
     "test": [
-        "flake8>=3.8",           # linter for pep8 compliance
-        "autopep8>=1.5",         # automatically fixes some pep8 errors
-        "torch-testing>=0.0.2",  # pytorch assertion library
+        "flake8>=3.8",  # linter for pep8 compliance
+        "autopep8>=1.5",  # automatically fixes some pep8 errors
+        "torch-testing>=0.0.2",  # pytorch assertion library
    ],
    "docs": [
-        "sphinx>=3.2.1",               # documentation library
+        "sphinx>=3.2.1",  # documentation library
        "sphinx-autobuild>=2020.9.1",  # documentation live reload
-        "sphinx-rtd-theme>=0.5.0",     # documentation theme
-        "sphinx-automodapi>=0.13",     # autogenerate docs for modules
+        "sphinx-rtd-theme>=0.5.0",  # documentation theme
+        "sphinx-automodapi>=0.13",  # autogenerate docs for modules
    ],
    "comet": [
-        "comet-ml>=3.28.3",  # experiment tracking using Comet.ml
-    ]
+        "comet-ml>=3.28.3",  # experiment tracking using Comet.ml
+    ],
 }
 
-extras["all"] = extras["atari"] + extras["box2d"] + extras["pybullet"] + extras["ma-atari"] + extras["comet"]
+extras["all"] = (
+    extras["atari"]
+    + extras["box2d"]
+    + extras["pybullet"]
+    + extras["ma-atari"]
+    + extras["comet"]
+)
 extras["dev"] = extras["all"] + extras["test"] + extras["docs"] + extras["comet"]
 
 setup(
@@ -47,26 +54,26 @@
     author="Chris Nota",
     author_email="cnota@cs.umass.edu",
     entry_points={
-        'console_scripts': [
-            'all-atari=scripts.atari:main',
-            'all-classic=scripts.classic:main',
-            'all-continuous=scripts.continuous:main',
-            'all-plot=scripts.plot:main',
-            'all-watch-atari=scripts.watch_atari:main',
-            'all-watch-classic=scripts.watch_classic:main',
-            'all-watch-continuous=scripts.watch_continuous:main',
-            'all-benchmark-atari=benchmarks.atari40:main',
-            'all-benchmark-pybullet=benchmarks.pybullet:main',
+        "console_scripts": [
+            "all-atari=scripts.atari:main",
+            "all-classic=scripts.classic:main",
+            "all-continuous=scripts.continuous:main",
+            "all-plot=scripts.plot:main",
+            "all-watch-atari=scripts.watch_atari:main",
+            "all-watch-classic=scripts.watch_classic:main",
+            "all-watch-continuous=scripts.watch_continuous:main",
+            "all-benchmark-atari=benchmarks.atari40:main",
+            "all-benchmark-pybullet=benchmarks.pybullet:main",
         ],
     },
     install_requires=[
-        "gymnasium~={}".format(GYM_VERSION),  # common environment interface
-        "numpy>=1.22.3",                      # math library
-        "matplotlib>=3.5.1",                  # plotting library
-        "opencv-python-headless>=4.0.0",      # used by atari wrappers
-        "torch>=1.11.0",                      # core deep learning library
-        "tensorboard>=2.8.0",                 # logging and visualization
-        "cloudpickle>=2.0.0",                 # used to copy environments
+        "gymnasium~={}".format(GYM_VERSION),  # common environment interface
+        "numpy>=1.22.3",  # math library
+        "matplotlib>=3.5.1",  # plotting library
+        "opencv-python-headless>=4.0.0",  # used by atari wrappers
+        "torch>=1.11.0",  # core deep learning library
+        "tensorboard>=2.8.0",  # logging and visualization
+        "cloudpickle>=2.0.0",  # used to copy environments
     ],
-    extras_require=extras
+    extras_require=extras,
 )