Multi-agent reinforcement training demo #797

Open · wants to merge 5 commits into base: develop
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -19,6 +19,7 @@ process
- Added glb models for pedestrians and motorcycles
- Added `--allow-offset-map` option for `scl scenario build` to prevent auto-shifting of Sumo road networks
- Added options in `DoneCriteria` to trigger the ego agent's done condition based on other agents' done status
- Added a multi-agent adversarial training demo, game of tag, under the examples folder
### Changed
- Refactored SMARTS class to not inherit from Panda3D's ShowBase; it's aggregated instead. See issue #597.
### Fixed
49 changes: 49 additions & 0 deletions examples/game_of_tag/README.md
@@ -0,0 +1,49 @@
# Game of Tag
This directory contains a multi-agent adversarial training demo. In the demo, there is a predator vehicle and a prey vehicle.
The predator vehicle's goal is to catch the prey, and the prey vehicle's goal is to avoid getting caught.

## Run training
```
python examples/game_of_tag/game_of_tag.py examples/game_of_tag/scenarios/game_of_tag_demo_map/
```

## Run checkpoint
```
python examples/game_of_tag/run_checkpoint.py examples/game_of_tag/scenarios/game_of_tag_demo_map/
```

## Setup:
### Rewards
The reward formula is `0.5/(distance - COLLIDE_DISTANCE)^2`, capped at 10.

- COLLIDE_DISTANCE is the observed distance between the two vehicles when they collide. Since each vehicle's position is measured at its center, the distance at the moment of collision is not exactly 0.
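
A minimal sketch of this reward term, assuming a placeholder value for `COLLIDE_DISTANCE` (the real value is calibrated from the collision distance observed in the simulator):

```
# Minimal sketch of the capped inverse-square distance reward described above.
# COLLIDE_DISTANCE below is a placeholder; the real value comes from the
# distance observed between vehicle centers at the moment of collision.
COLLIDE_DISTANCE = 1.5  # placeholder value, for illustration only
REWARD_CAP = 10.0


def distance_reward(distance: float) -> float:
    """Return 0.5 / (distance - COLLIDE_DISTANCE)^2, capped at REWARD_CAP."""
    gap = max(distance - COLLIDE_DISTANCE, 1e-3)  # guard against division by zero
    return min(0.5 / gap ** 2, REWARD_CAP)
```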

### Common Reward:
- Off road: -10

#### Prey:
- Collision with predator: -10
- Distance to predator (d): 0.5/(d - COLLIDE_DISTANCE)^2

#### Predator:
- Collision with prey: -10
- Distance to prey (d): 0.5/(d - COLLIDE_DISTANCE)^2

### Action:
Speed selection in m/s: [0, 3, 6, 9]

Lane change selection relative to current lane: [-1, 0, 1]
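
For illustration, an action adapter along these lines could map a pair of discrete indices onto the `(target_speed, lane_change)` pair used by the `LaneWithContinuousSpeed` action space; the actual adapter lives in `tag_adapters.py` and may differ:

```
# Illustrative only: the real adapter is defined in tag_adapters.py.
SPEEDS = [0.0, 3.0, 6.0, 9.0]  # m/s
LANE_CHANGES = [-1, 0, 1]      # relative to the current lane


def example_action_adapter(model_action):
    # model_action is assumed to be a (speed_index, lane_change_index) pair.
    speed_idx, lane_idx = model_action
    return (SPEEDS[speed_idx], LANE_CHANGES[lane_idx])
```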

## Output a model:
Currently, RLlib does not provide an implementation for exporting a PyTorch model.

Replace `export_model`'s implementation in `ray/rllib/policy/torch_policy.py` with the following:
```
torch.save(self.model.state_dict(), f"{export_dir}/model.pt")
```
Then follow the steps in `game_of_tag.py` to export the model.
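
After the patched `export_model` has written `model.pt`, the saved state dict can be inspected or reloaded with plain PyTorch, for example (the path below is illustrative):

```
# Sketch: inspect the exported weights. Loading them back into CustomFCModel
# requires the same observation/action spaces and model config used in training.
import torch

state_dict = torch.load("models/predator_model/model.pt")
print(list(state_dict.keys()))
```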

## Possible next steps
- Increase the number of agents to 2 predators and 2 prey.
This requires modelling the reward so that the game remains zero-sum. The complication shows up when modelling
the distance reward between 2 predators and 1 prey: if each agent's reward only considers its nearest opponent,
the predator and prey rewards no longer sum to 0, because both predators collect the full reward from the single
prey while the prey is only penalized with respect to one predator. Handling this requires the predators to know
about each other (or the prey to know about other prey), and the prey to know about multiple predators; see the
sketch after this list.
- Add an attribute to the observations indicating whether the ego vehicle is in front of or behind the target
vehicle. This may help the ego vehicle decide whether it should slow down or speed up.
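
One possible way to keep the multi-predator, multi-prey game zero-sum (a sketch, not part of this PR) is to sum the distance term over every predator-prey pair and give each prey the mirrored penalty, so the rewards cancel by construction:

```
# Illustrative zero-sum distance reward for several predators and prey.
# Every predator-prey pair contributes +r to the predator and -r to the prey,
# so the per-step rewards sum to zero by construction.
def pairwise_rewards(predator_positions, prey_positions, distance_reward):
    predator_r = [0.0] * len(predator_positions)
    prey_r = [0.0] * len(prey_positions)
    for i, (px, py) in enumerate(predator_positions):
        for j, (qx, qy) in enumerate(prey_positions):
            r = distance_reward(((px - qx) ** 2 + (py - qy) ** 2) ** 0.5)
            predator_r[i] += r
            prey_r[j] -= r
    return predator_r, prey_r
```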
282 changes: 282 additions & 0 deletions examples/game_of_tag/game_of_tag.py
@@ -0,0 +1,282 @@
"""Let's play tag!

A predator-prey multi-agent example built on top of RLlib to facilitate further
developments on multi-agent support for HiWay (including design, performance,
research, and scaling).

The predator and prey use separate policies. A predator "catches" its prey when
it collides with the other vehicle. There can be multiple predators and
multiple prey in a map. Social vehicles act as obstacles that both the
predator and prey must avoid.
"""
import argparse
import os
import random
import time
import multiprocessing
import ray


import numpy as np
from typing import List
from ray import tune
from ray.rllib.utils import try_import_tf
from ray.rllib.models import ModelCatalog
from ray.tune import Stopper
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.tune.schedulers import PopulationBasedTraining
from ray.rllib.agents.ppo import PPOTrainer
from pathlib import Path

from smarts.env.rllib_hiway_env import RLlibHiWayEnv
from smarts.core.agent import AgentSpec, Agent
from smarts.core.controllers import ActionSpaceType
from smarts.core.agent_interface import (
AgentInterface,
AgentType,
DoneCriteria,
AgentsAliveDoneCriteria,
AgentsListAlive,
)
from smarts.core.utils.file import copy_tree


from examples.game_of_tag.tag_adapters import *
from examples.game_of_tag.model import CustomFCModel


# Add custom metrics to your tensorboard using these callbacks
# see: https://ray.readthedocs.io/en/latest/rllib-training.html#callbacks-and-custom-metrics
def on_episode_start(info):
episode = info["episode"]
print("episode {} started".format(episode.episode_id))


def on_episode_step(info):
    episode = info["episode"]
    single_agent_id = list(episode._agent_to_last_obs)[0]
    # The last raw observation is available here for logging custom metrics.
    obs = episode.last_raw_obs_for(single_agent_id)


def on_episode_end(info):
episode = info["episode"]


def explore(config):
# ensure we collect enough timesteps to do sgd
if config["train_batch_size"] < config["sgd_minibatch_size"] * 2:
config["train_batch_size"] = config["sgd_minibatch_size"] * 2
# ensure we run at least one sgd iter
if config["num_sgd_iter"] < 1:
config["num_sgd_iter"] = 1
return config


PREDATOR_POLICY = "predator_policy"
PREY_POLICY = "prey_policy"


def policy_mapper(agent_id):
if agent_id in PREDATOR_IDS:
return PREDATOR_POLICY
elif agent_id in PREY_IDS:
return PREY_POLICY


class TimeStopper(Stopper):
def __init__(self):
self._start = time.time()
        # Obvious tag behaviour typically emerges within about 6 hours of training.
        self._deadline = 48 * 60 * 60  # train for 48 hours

def __call__(self, trial_id, result):
return False

def stop_all(self):
return time.time() - self._start > self._deadline


tf = try_import_tf()

ModelCatalog.register_custom_model("CustomFCModel", CustomFCModel)

rllib_agents = {}

shared_interface = AgentInterface(
max_episode_steps=1500,
neighborhood_vehicles=True,
waypoints=True,
action=ActionSpaceType.LaneWithContinuousSpeed,
)
shared_interface.done_criteria = DoneCriteria(
off_route=False,
wrong_way=False,
collision=True,
agents_alive=AgentsAliveDoneCriteria(
agent_lists_alive=[
AgentsListAlive(agents_list=PREY_IDS, minimum_agents_alive_in_list=1),
AgentsListAlive(agents_list=PREDATOR_IDS, minimum_agents_alive_in_list=1),
]
),
)

for agent_id in PREDATOR_IDS:
rllib_agents[agent_id] = {
"agent_spec": AgentSpec(
interface=shared_interface,
agent_builder=lambda: TagModelAgent(
os.path.join(os.path.dirname(os.path.realpath(__file__)), "model"),
OBSERVATION_SPACE,
),
observation_adapter=observation_adapter,
reward_adapter=predator_reward_adapter,
action_adapter=action_adapter,
),
"observation_space": OBSERVATION_SPACE,
"action_space": ACTION_SPACE,
}

for agent_id in PREY_IDS:
rllib_agents[agent_id] = {
"agent_spec": AgentSpec(
interface=shared_interface,
agent_builder=lambda: TagModelAgent(
os.path.join(os.path.dirname(os.path.realpath(__file__)), "model"),
OBSERVATION_SPACE,
),
observation_adapter=observation_adapter,
reward_adapter=prey_reward_adapter,
action_adapter=action_adapter,
),
"observation_space": OBSERVATION_SPACE,
"action_space": ACTION_SPACE,
}


def build_tune_config(scenario, headless=True, sumo_headless=False):
rllib_policies = {
policy_mapper(agent_id): (
None,
rllib_agent["observation_space"],
rllib_agent["action_space"],
{"model": {"custom_model": "CustomFCModel"}},
)
for agent_id, rllib_agent in rllib_agents.items()
}

tune_config = {
"env": RLlibHiWayEnv,
"framework": "torch",
"log_level": "WARN",
"num_workers": 3,
"explore": True,
"horizon": 10000,
"env_config": {
"seed": 42,
"sim_name": "game_of_tag_works?",
"scenarios": [os.path.abspath(scenario)],
"headless": headless,
"sumo_headless": sumo_headless,
"agent_specs": {
agent_id: rllib_agent["agent_spec"]
for agent_id, rllib_agent in rllib_agents.items()
},
},
"multiagent": {
"policies": rllib_policies,
"policies_to_train": [PREDATOR_POLICY, PREY_POLICY],
"policy_mapping_fn": policy_mapper,
},
"callbacks": {
"on_episode_start": on_episode_start,
"on_episode_step": on_episode_step,
"on_episode_end": on_episode_end,
},
}
return tune_config


def main(args):
pbt = PopulationBasedTraining(
time_attr="time_total_s",
metric="episode_reward_mean",
mode="max",
perturbation_interval=300,
resample_probability=0.25,
# Specifies the mutations of these hyperparams
hyperparam_mutations={
"lambda": lambda: random.uniform(0.9, 1.0),
"clip_param": lambda: random.uniform(0.01, 0.5),
"kl_coeff": lambda: 0.3,
"lr": [1e-3],
"sgd_minibatch_size": lambda: 128,
"train_batch_size": lambda: 4000,
"num_sgd_iter": lambda: 30,
},
custom_explore_fn=explore,
)
local_dir = os.path.expanduser(args.result_dir)

tune_config = build_tune_config(args.scenario)

tune.run(
        PPOTrainer,  # RLlib supports using PPO in a multi-agent setting
name="lets_play_tag",
stop=TimeStopper(),
# XXX: Every X iterations perform a _ray actor_ checkpoint (this is
# different than _exporting_ a TF/PT checkpoint).
checkpoint_freq=5,
checkpoint_at_end=True,
# XXX: Beware, resuming after changing tune params will not pick up
# the new arguments as they are stored alongside the checkpoint.
resume=args.resume_training,
# restore="path_to_training_checkpoint/checkpoint_x/checkpoint-x",
local_dir=local_dir,
reuse_actors=True,
max_failures=0,
export_formats=["model", "checkpoint"],
config=tune_config,
scheduler=pbt,
)

    # # To output a model:
    # # 1: comment out tune.run above and uncomment the following code
    # # 2: replace the checkpoint path with your training checkpoint path
    # # 3: inject the code into RLlib as described in README.md, then run
# checkpoint_path = os.path.join(
# os.path.dirname(os.path.realpath(__file__)), "models/checkpoint_360/checkpoint-360"
# )
# ray.init(num_cpus=2)
# training_agent = PPOTrainer(env=RLlibHiWayEnv,config=tune_config)
# training_agent.restore(checkpoint_path)
# prefix = "model.ckpt"
# model_dir = os.path.join(
# os.path.dirname(os.path.realpath(__file__)), "models/predator_model"
# )
# training_agent.export_policy_model(model_dir, PREDATOR_POLICY)
# model_dir = os.path.join(
# os.path.dirname(os.path.realpath(__file__)), "models/prey_model"
# )
# training_agent.export_policy_model(model_dir, PREY_POLICY)


if __name__ == "__main__":
parser = argparse.ArgumentParser("rllib-example")
parser.add_argument(
"scenario",
type=str,
help="Scenario to run (see scenarios/ for some samples you can use)",
)
parser.add_argument(
"--resume_training",
default=False,
action="store_true",
help="Resume the last trained example",
)
parser.add_argument(
"--result_dir",
type=str,
default="~/ray_results",
help="Directory containing results (and checkpointing)",
)
args = parser.parse_args()
main(args)
37 changes: 37 additions & 0 deletions examples/game_of_tag/model.py
@@ -0,0 +1,37 @@
import gym
import torch
from torch import nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFCNet


class CustomFCModel(TorchModelV2, nn.Module):
"""Example of interpreting repeated observations."""

def __init__(
self,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
num_outputs: int,
model_config,
name: str,
):
super(CustomFCModel, self).__init__(
obs_space=obs_space,
action_space=action_space,
num_outputs=num_outputs,
model_config=model_config,
name=name,
)
nn.Module.__init__(self)

self.model = TorchFCNet(
obs_space, action_space, num_outputs, model_config, name
)

    def forward(self, input_dict, state, seq_lens):
        return self.model.forward(input_dict, state, seq_lens)

def value_function(self):
return self.model.value_function()
Binary file added examples/game_of_tag/models/prey_model/model.pt