Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/scripts-folder #286

Merged
merged 9 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions all/environments/atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@


class AtariEnvironment(Environment):
def __init__(self, name, device='cpu'):
def __init__(self, name, device='cpu', **gym_make_kwargs):

# construct the environment
env = gymnasium.make(name + "NoFrameskip-v4")
env = gymnasium.make(name + "NoFrameskip-v4", **gym_make_kwargs)

# apply a subset of wrappers
env = NoopResetEnv(env, noop_max=30)
Expand Down
5 changes: 3 additions & 2 deletions all/environments/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ class GymEnvironment(Environment):
name (str, optional): the name of the environment
device (str, optional): the device on which tensors will be stored
legacy_gym (str, optional): If true, calls gym.make() instead of gymnasium.make()
**gym_make_kwargs: kwargs passed to gymnasium.make(id, **gym_make_kwargs)
'''

def __init__(self, id, device=torch.device('cpu'), name=None, legacy_gym=False):
def __init__(self, id, device=torch.device('cpu'), name=None, legacy_gym=False, **gym_make_kwargs):
if legacy_gym:
import gym
self._gym = gym
else:
self._gym = gymnasium
self._env = self._gym.make(id)
self._env = self._gym.make(id, **gym_make_kwargs)
self._id = id
self._name = name if name else id
self._state = None
Expand Down
4 changes: 2 additions & 2 deletions all/environments/multiagent_pettingzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def step(self, action):
def seed(self, seed):
self._env.seed(seed)

def render(self, mode='human'):
return self._env.render(mode=mode)
def render(self, **kwargs):
return self._env.render(**kwargs)

def close(self):
self._env.close()
Expand Down
2 changes: 1 addition & 1 deletion all/experiments/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _log_test_episode(self, episode, returns):
def _log_test(self, returns):
if not self._quiet:
mean = np.mean(returns)
sem = np.var(returns) / np.sqrt(len(returns))
sem = np.std(returns) / np.sqrt(len(returns))
print('test returns (mean ± sem): {} ± {}'.format(mean, sem))
self._logger.add_summary('returns-test', np.mean(returns), np.std(returns))

Expand Down
6 changes: 3 additions & 3 deletions all/experiments/multiagent_env_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def __init__(
self._logger = self._make_logger(logdir, self._name, env.name, verbose, logger)
self._agent = preset.agent(logger=self._logger, train_steps=train_steps)
self._env = env
self._episode = 0
self._frame = 0
self._episode = 1
self._frame = 1
self._logdir = logdir
self._preset = preset
self._quiet = quiet
Expand Down Expand Up @@ -171,7 +171,7 @@ def _log_test(self, returns):
for agent, agent_returns in returns.items():
if not self._quiet:
mean = np.mean(agent_returns)
sem = np.variance(agent_returns) / np.sqrt(len(agent_returns))
sem = np.std(agent_returns) / np.sqrt(len(agent_returns))
print('{} test returns (mean ± sem): {} ± {}'.format(agent, mean, sem))
self._logger.add_summary('{}/returns-test'.format(agent), np.mean(agent_returns), np.std(agent_returns))

Expand Down
13 changes: 8 additions & 5 deletions all/experiments/multiagent_env_experiment_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import random
import unittest
import numpy as np
import torch
Expand All @@ -16,9 +17,10 @@ def _make_logger(self, logdir, agent_name, env_name, verbose, logger):

class TestMultiagentEnvExperiment(unittest.TestCase):
def setUp(self):
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
self.env = MultiagentAtariEnv('space_invaders_v2', device='cpu')
self.env = MultiagentAtariEnv('space_invaders_v2', device='cpu', seed=0)
self.env.reset(seed=0)
self.experiment = None

Expand All @@ -34,10 +36,11 @@ def test_writes_training_returns(self):
experiment = MockExperiment(self.make_preset(), self.env, quiet=True, save_freq=float('inf'))
experiment.train(episodes=3)
self.maxDiff = None
self.assertEqual(experiment._logger.data, {
'eval/first_0/returns/frame': {'values': [705.0, 490.0, 230.0, 435.0], 'steps': [808, 1580, 2120, 3300]},
'eval/second_0/returns/frame': {'values': [115.0, 525.0, 415.0, 665.0], 'steps': [808, 1580, 2120, 3300]}
})
# could not get the exact numbers to be reproducible across environments :(
self.assertEqual(len(experiment._logger.data['eval/first_0/returns/frame']['values']), 3)
self.assertEqual(len(experiment._logger.data['eval/first_0/returns/frame']['steps']), 3)
self.assertEqual(len(experiment._logger.data['eval/second_0/returns/frame']['values']), 3)
self.assertEqual(len(experiment._logger.data['eval/second_0/returns/frame']['steps']), 3)

def test_writes_test_returns(self):
experiment = MockExperiment(self.make_preset(), self.env, quiet=True, save_freq=float('inf'))
Expand Down
1 change: 1 addition & 0 deletions all/experiments/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def run_experiment(
verbose=verbose,
logger=logger
)
experiment.save()
experiment.train(frames=frames)
experiment.save()
experiment.test(episodes=test_episodes)
Expand Down
22 changes: 14 additions & 8 deletions all/experiments/single_env_experiment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,15 @@ def close(self):

class MockExperiment(SingleEnvExperiment):
def _make_logger(self, logdir, agent_name, env_name, verbose, logger):
self._logger = MockLogger(self, agent_name + '_' + env_name, verbose)
self._logger = MockLogger(self, agent_name + "_" + env_name, verbose)
return self._logger


class TestSingleEnvExperiment(unittest.TestCase):
def setUp(self):
np.random.seed(0)
torch.manual_seed(0)
self.env = GymEnvironment('CartPole-v0')
self.env = GymEnvironment("CartPole-v0")
self.env.reset(seed=0)
self.experiment = None

Expand All @@ -66,15 +66,17 @@ def test_adds_default_name(self):
self.assertEqual(experiment._logger.label, "dqn_CartPole-v0")

def test_adds_custom_name(self):
experiment = MockExperiment(self.make_preset(), self.env, name='dqn', quiet=True)
experiment = MockExperiment(
self.make_preset(), self.env, name="dqn", quiet=True
)
self.assertEqual(experiment._logger.label, "dqn_CartPole-v0")

def test_writes_training_returns_eps(self):
experiment = MockExperiment(self.make_preset(), self.env, quiet=True)
experiment.train(episodes=3)
np.testing.assert_equal(
experiment._logger.data["eval/returns/episode"]["values"],
np.array([22., 17., 28.]),
np.array([22.0, 17.0, 28.0]),
)
np.testing.assert_equal(
experiment._logger.data["eval/returns/episode"]["steps"],
Expand All @@ -95,21 +97,25 @@ def test_writes_test_returns(self):
np.testing.assert_approx_equal(
np.array(experiment._logger.data["summary/returns-test/std"]["values"]),
np.array([expected_std]),
significant=4
significant=4,
)
np.testing.assert_equal(
experiment._logger.data["summary/returns-test/mean"]["steps"],
np.array([93]),
)

def test_writes_loss(self):
experiment = MockExperiment(self.make_preset(), self.env, quiet=True, verbose=True)
experiment = MockExperiment(
self.make_preset(), self.env, quiet=True, verbose=True
)
self.assertTrue(experiment._logger.verbose)
experiment = MockExperiment(self.make_preset(), self.env, quiet=True, verbose=False)
experiment = MockExperiment(
self.make_preset(), self.env, quiet=True, verbose=False
)
self.assertFalse(experiment._logger.verbose)

def make_preset(self):
return dqn.device('cpu').env(self.env).build()
return dqn.device("cpu").env(self.env).build()


if __name__ == "__main__":
Expand Down
22 changes: 9 additions & 13 deletions all/experiments/watch.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,27 @@
import os
import time
import torch
import gymnasium
from all.agents import Agent
import sys


def watch(agent, env, fps=60):
def watch(agent, env, fps=60, n_episodes=sys.maxsize):
action = None
returns = 0
# have to call this before initial reset for pybullet envs
env.render(mode="human")
env.reset()

while True:
for _ in range(n_episodes):
env.render()
action = agent.act(env.state)
env.step(action)
returns += env.state.reward

time.sleep(1 / fps)
if env.state.done:
print('returns:', returns)
env.reset()
returns = 0
else:
env.step(action)
env.render()

time.sleep(1 / fps)


def load_and_watch(filename, env, fps=60):
def load_and_watch(filename, env, fps=60, n_episodes=sys.maxsize):
agent = torch.load(filename).test_agent()
watch(agent, env, fps=fps)
watch(agent, env, fps=fps, n_episodes=n_episodes)
28 changes: 28 additions & 0 deletions all/experiments/watch_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from .watch import load_and_watch
import unittest
from unittest import mock
import torch
from all.environments import GymEnvironment


class MockAgent():
def act(self):
# sample from cartpole action space
return torch.randint(0, 2, [])


class MockPreset():
def __init__(self, filename):
self.filename = filename

def test_agent(self):
return MockAgent


class WatchTest(unittest.TestCase):
@mock.patch('torch.load', lambda filename: MockPreset(filename))
@mock.patch('time.sleep', mock.MagicMock())
def test_load_and_watch(self):
env = mock.MagicMock(GymEnvironment("CartPole-v0", render_mode="rgb_array"))
load_and_watch("file.name", env, n_episodes=3)
self.assertEqual(env.reset.call_count, 4)
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
10 changes: 6 additions & 4 deletions scripts/multiagent_atari.py → all/scripts/multiagent_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@ def main():

env = MultiagentAtariEnv(args.env, device=args.device)

assert len(env.agents) == len(args.agents), f"Must specify {len(env.agents)} agents for this environment."

presets = {
agent_id: getattr(atari, agent_type)
.hyperparameters(replay_buffer_size=args.replay_buffer_size)
.device(args.device)
.env(env.subenvs[agent_id])
.build()
.hyperparameters(replay_buffer_size=args.replay_buffer_size)
.device(args.device)
.env(env.subenvs[agent_id])
.build()
for agent_id, agent_type in zip(env.agents, args.agents)
}

Expand Down
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion scripts/watch_atari.py → all/scripts/watch_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def main():
help="Playback speed",
)
args = parser.parse_args()
env = AtariEnvironment(args.env, device=args.device)
env = AtariEnvironment(args.env, device=args.device, render_mode="human")
load_and_watch(args.filename, env, fps=args.fps)


Expand Down
2 changes: 1 addition & 1 deletion scripts/watch_classic.py → all/scripts/watch_classic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def main():
help="Playback speed",
)
args = parser.parse_args()
env = GymEnvironment(args.env, device=args.device)
env = GymEnvironment(args.env, device=args.device, render_mode="human")
load_and_watch(args.filename, env, fps=args.fps)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ def main():
args = parser.parse_args()

if args.env in ENVS:
env = GymEnvironment(args.env, device=args.device)
env = GymEnvironment(args.env, device=args.device, render_mode="human")
elif 'BulletEnv' in args.env or args.env in PybulletEnvironment.short_names:
env = PybulletEnvironment(args.env, device=args.device)
env = PybulletEnvironment(args.env, device=args.device, render_mode="human")
else:
env = GymEnvironment(args.env, device=args.device)
env = GymEnvironment(args.env, device=args.device, render_mode="human")

load_and_watch(args.filename, env, fps=args.fps)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def main():
action="store_true", default=False, help="Reload the model from disk after every episode"
)
args = parser.parse_args()
env = MultiagentAtariEnv(args.env, device=args.device)
env = MultiagentAtariEnv(args.env, device=args.device, render_mode="human")
watch(env, args.filename, args.fps, args.reload)


Expand Down
18 changes: 9 additions & 9 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,15 @@
author_email="[email protected]",
entry_points={
"console_scripts": [
"all-atari=scripts.atari:main",
"all-classic=scripts.classic:main",
"all-continuous=scripts.continuous:main",
"all-plot=scripts.plot:main",
"all-watch-atari=scripts.watch_atari:main",
"all-watch-classic=scripts.watch_classic:main",
"all-watch-continuous=scripts.watch_continuous:main",
"all-benchmark-atari=benchmarks.atari40:main",
"all-benchmark-pybullet=benchmarks.pybullet:main",
"all-atari=all.scripts.atari:main",
"all-classic=all.scripts.classic:main",
"all-continuous=all.scripts.continuous:main",
"all-multiagent-atari=all.scripts.multiagent_atari:main",
"all-plot=all.scripts.plot:main",
"all-watch-atari=all.scripts.watch_atari:main",
"all-watch-classic=all.scripts.watch_classic:main",
"all-watch-continuous=all.scripts.watch_continuous:main",
"all-watch-multiagent-atari=all.scripts.watch_multiagent_atari:main",
],
},
install_requires=[
Expand Down
Loading