Commit 189d400 (1 parent: b2000b3). Showing 6 changed files with 519 additions and 0 deletions.
@@ -0,0 +1,5 @@
# Wolf-Sheep Predation Model using Petting Zoo

This example converts a Mesa model (Wolf-Sheep predation) into a PettingZoo parallel environment and then trains it with RLlib. The same approach can be used to convert other Mesa models into PettingZoo environments.

TODO: Make it work for a variable number of agents.
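
For reference, a minimal random-action rollout against the environment defined below might look like the following sketch (not part of the commit; it assumes `env.agents` tracks the live agents after each step):

```python
# Minimal sketch: drive the WolfSheep environment with random actions.
from environment import WolfSheep

env = WolfSheep()
obs, infos = env.reset()
for _ in range(10):
    # Sample one action per live agent (Discrete(4): the four moves)
    actions = {a: env.action_space(a).sample() for a in env.agents}
    obs, rewards, done, truncated, infos = env.step(actions)
    if all(done.values()):
        break
```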
@@ -0,0 +1,149 @@
import mesa
import numpy as np


def move(self, action):
    # Get possible steps for the agent (neighboring cells, wrapped on the torus)
    possible_steps = self.model.grid.get_neighborhood(
        self.pos,
        moore=True,
        include_center=False
    )
    action = int(action)  # Convert action to integer

    # Uncomment this to test random baselines
    # action = np.random.randint(0, 4)

    # Move the agent based on the action
    if action == 0:
        new_position = (self.pos[0] + 1, self.pos[1])
    elif action == 1:
        new_position = (self.pos[0] - 1, self.pos[1])
    elif action == 2:
        new_position = (self.pos[0], self.pos[1] - 1)
    else:  # action == 3
        new_position = (self.pos[0], self.pos[1] + 1)

    # Wrap the candidate position on the torus so moves across the grid edge
    # match the (wrapped) coordinates returned by get_neighborhood
    new_position = self.model.grid.torus_adj(new_position)

    # Check if the new position is valid, then move the agent
    if new_position in possible_steps:
        self.model.grid.move_agent(self, new_position)


class Sheep(mesa.Agent):
    """
    A sheep that walks around, reproduces (asexually) and gets eaten.
    """

    energy = None

    def __init__(self, unique_id, pos, model, energy=None):
        super().__init__(unique_id, model)
        self.energy = energy
        self.done = False
        self.pos = pos
        self.living = True
        self.time = 0

    def step(self, action):
        """
        A model step. Move, then eat grass and reproduce.
        """
        if self.living:
            self.time += 1
            move(self, action)

            if self.model.grass:
                # Reduce energy
                self.energy -= 1

                # If there is fully grown grass available, eat it
                this_cell = self.model.grid.get_cell_list_contents([self.pos])
                grass_patch = next(obj for obj in this_cell if isinstance(obj, GrassPatch))
                if grass_patch.fully_grown:
                    self.energy += self.model.sheep_gain_from_food
                    grass_patch.fully_grown = False

                # Death
                if self.energy < 0:
                    self.living = False

            if self.living and self.random.random() < self.model.sheep_reproduce:
                # Create a new sheep, splitting the parent's energy with it
                if self.model.grass:
                    self.energy /= 2
                lamb = Sheep(
                    self.model.next_id(), self.pos, self.model, self.energy
                )
                self.model.grid.place_agent(lamb, self.pos)
                self.model.schedule.add(lamb)


class Wolf(mesa.Agent):
    """
    A wolf that walks around, reproduces (asexually) and eats sheep.
    """

    energy = None

    def __init__(self, unique_id, pos, model, energy=None):
        super().__init__(unique_id, model)
        self.energy = energy
        self.done = False
        self.pos = pos
        self.living = True
        self.time = 0

    def step(self, action):
        if self.living:
            self.time += 1
            move(self, action)
            self.energy -= 1

            # If there are living sheep present, eat one
            this_cell = self.model.grid.get_cell_list_contents([self.pos])
            sheep = [obj for obj in this_cell if isinstance(obj, Sheep) and obj.living]
            if len(sheep) > 0:
                sheep_to_eat = self.random.choice(sheep)
                self.energy += self.model.wolf_gain_from_food

                sheep_to_eat.living = False

            # Death or reproduction
            if self.energy < 0:
                self.living = False
            else:
                if self.random.random() < self.model.wolf_reproduce:
                    # Create a new wolf cub, splitting the parent's energy with it
                    self.energy /= 2
                    cub = Wolf(
                        self.model.next_id(), self.pos, self.model, self.energy
                    )
                    self.model.grid.place_agent(cub, cub.pos)
                    self.model.schedule.add(cub)


class GrassPatch(mesa.Agent):
    """
    A patch of grass that grows at a fixed rate and is eaten by sheep.
    """

    def __init__(self, unique_id, pos, model, fully_grown, countdown):
        """
        Creates a new patch of grass.

        Args:
            fully_grown: (boolean) Whether the patch of grass is fully grown or not
            countdown: Time for the patch of grass to be fully grown again
        """
        super().__init__(unique_id, model)
        self.fully_grown = fully_grown
        self.countdown = countdown
        self.pos = pos

    def step(self):
        if not self.fully_grown:
            if self.countdown <= 0:
                # Set as fully grown
                self.fully_grown = True
                self.countdown = self.model.grass_regrowth_time
            else:
                self.countdown -= 1
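
The environment below imports `RandomActivationByTypeFiltered` from `scheduler`, one of the changed files not shown in this excerpt. Judging from how it is used (a per-type agent count with an optional filter), a minimal sketch could look like this; the actual file in the commit may differ:

```python
# Hypothetical sketch of scheduler.py (the real file is not shown here).
# Builds on mesa's RandomActivationByType, which groups agents by class.
import mesa


class RandomActivationByTypeFiltered(mesa.time.RandomActivationByType):
    """RandomActivationByType with a filterable per-type agent count."""

    def get_type_count(self, type_class, filter_func=None):
        # Count agents of the given type that pass the (optional) filter
        count = 0
        for agent in self.agents_by_type[type_class].values():
            if filter_func is None or filter_func(agent):
                count += 1
        return count
```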
@@ -0,0 +1,146 @@
import mesa
import numpy as np
import gymnasium as gym
import functools
from pettingzoo import ParallelEnv
from agents import Wolf, Sheep, GrassPatch
from utils_function import get_observation, remove_dead_agents, create_initial_agents
from scheduler import RandomActivationByTypeFiltered


class WolfSheep(mesa.Model, ParallelEnv):
    """
    Wolf-Sheep Predation Model
    """

    description = (
        "A model for simulating wolf and sheep (predator-prey) ecosystem dynamics."
    )

    def __init__(
        self,
        width=20,
        height=20,
        initial_sheep=100,
        initial_wolves=25,
        sheep_reproduce=0.0,
        wolf_reproduce=0.0,
        wolf_gain_from_food=20,
        grass=True,
        grass_regrowth_time=30,
        sheep_gain_from_food=4,
    ):
        """
        Create a new Wolf-Sheep model with the given parameters.
        """
        super().__init__()
        # Set parameters
        self.width = width
        self.height = height
        self.initial_sheep = initial_sheep
        self.initial_wolves = initial_wolves
        self.sheep_reproduce = sheep_reproduce
        self.wolf_reproduce = wolf_reproduce
        self.wolf_gain_from_food = wolf_gain_from_food
        self.grass = grass
        self.grass_regrowth_time = grass_regrowth_time
        self.sheep_gain_from_food = sheep_gain_from_food

        self.schedule = RandomActivationByTypeFiltered(self)
        self.grid = mesa.space.MultiGrid(self.width, self.height, torus=True)
        self.datacollector = mesa.DataCollector(
            {
                "Wolves": lambda m: m.schedule.get_type_count(Wolf),
                "Sheep": lambda m: m.schedule.get_type_count(Sheep),
                "Grass": lambda m: m.schedule.get_type_count(
                    GrassPatch, lambda x: x.fully_grown
                ),
            }
        )

        create_initial_agents(self)

        self.running = True
        self.datacollector.collect(self)
        self.time = 0
        # PettingZoo bookkeeping: agent ids are the mesa unique_ids of the
        # sheep and wolves (grass patches are not controllable agents)
        self.agents = [a.unique_id for a in self.schedule.agents if isinstance(a, (Sheep, Wolf))]
        self.possible_agents = self.agents
        self.observation_spaces = {a: self.observation_space(a) for a in self.possible_agents}
        self.action_spaces = {a: self.action_space(a) for a in self.possible_agents}

    def step(self, action_dict):
        # Check if either wolves or sheep are extinct; if so, end the episode
        # by marking every controllable agent as dead
        if self.schedule.get_type_count(Wolf) == 0 or self.schedule.get_type_count(Sheep) == 0:
            for agent in self.schedule.agents:
                if isinstance(agent, (Sheep, Wolf)):
                    agent.living = False

        self.datacollector.collect(self)

        rewards = {a.unique_id: 0 for a in self.schedule.agents if isinstance(a, (Sheep, Wolf))}

        # Execute actions and accumulate energy-based rewards
        for agent in self.schedule.agents:
            if isinstance(agent, (Sheep, Wolf)):
                agent.step(action_dict[agent.unique_id])
                if isinstance(agent, Sheep):
                    rewards[agent.unique_id] += min(4, agent.energy - 4)
                else:
                    rewards[agent.unique_id] += min(4, agent.energy / 5 - 4)
            else:
                agent.step()

        for agent in self.schedule.agents:
            if isinstance(agent, (Sheep, Wolf)):
                if agent.unique_id not in rewards:
                    # Agents born this step have no action yet; give them a
                    # neutral reward
                    rewards[agent.unique_id] = 0
                if not agent.living:
                    agent.done = True
                    # Penalize dying early (before 25 steps)
                    rewards[agent.unique_id] = min(0, -(25 - agent.time))

        # Get observations
        obs = {a.unique_id: get_observation(self, a) for a in self.schedule.agents if isinstance(a, (Sheep, Wolf))}

        # Check if done
        done = {a.unique_id: a.done for a in self.schedule.agents if isinstance(a, (Sheep, Wolf))}

        self.time += 1

        # Episode time limit (signalled through the done dict here)
        if self.time > 500:
            done = {a.unique_id: True for a in self.schedule.agents if isinstance(a, (Sheep, Wolf))}

        # Per-agent truncation flags (always False here) and empty infos
        truncated = {a.unique_id: False for a in self.schedule.agents if isinstance(a, (Sheep, Wolf))}

        infos = {a.unique_id: {} for a in self.schedule.agents if isinstance(a, (Sheep, Wolf))}

        remove_dead_agents(self)

        return obs, rewards, done, truncated, infos

    def reset(self, seed=None, options=None):
        # Rebuild the schedule and grid and repopulate the model
        self.time = 0
        self.schedule = RandomActivationByTypeFiltered(self)
        self.grid = mesa.space.MultiGrid(self.width, self.height, torus=True)
        self.current_id = 0
        create_initial_agents(self)
        self.agents = [a.unique_id for a in self.schedule.agents if isinstance(a, (Sheep, Wolf))]
        obs = {a.unique_id: get_observation(self, a) for a in self.schedule.agents if isinstance(a, (Sheep, Wolf))}
        infos = {a: {} for a in self.agents}
        return obs, infos

    def render(self):
        # Render the environment to the screen
        pass

    @functools.lru_cache(maxsize=None)
    def observation_space(self, agent):
        # A 10x10 grid view with 4 boolean feature layers, plus the agent's
        # own energy
        return gym.spaces.Dict({'grid': gym.spaces.Box(low=0, high=1, shape=(10, 10, 4), dtype=bool),
                                'energy': gym.spaces.Box(low=-1, high=np.inf, shape=(1,), dtype=np.float32)
                                })

    @functools.lru_cache(maxsize=None)
    def action_space(self, agent):
        return gym.spaces.Discrete(4)
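
`utils_function.py` (providing `get_observation`, `remove_dead_agents`, and `create_initial_agents`) is also part of this commit but not shown above. A sketch of `get_observation` consistent with the declared observation space, a 10x10x4 boolean grid plus an energy scalar, might look like this; the window size, torus wrapping, and layer ordering are assumptions:

```python
# Hypothetical sketch of utils_function.get_observation (real file not shown).
# Assumes a 10x10 window around the agent with one boolean layer each for
# living sheep, living wolves, fully grown grass, and the agent itself.
import numpy as np

from agents import Wolf, Sheep, GrassPatch


def get_observation(model, agent):
    grid = np.zeros((10, 10, 4), dtype=bool)
    x0, y0 = agent.pos
    for dx in range(-5, 5):
        for dy in range(-5, 5):
            # Wrap coordinates on the torus grid
            x, y = (x0 + dx) % model.width, (y0 + dy) % model.height
            for obj in model.grid.get_cell_list_contents([(x, y)]):
                if isinstance(obj, Sheep) and obj.living:
                    grid[dx + 5, dy + 5, 0] = True
                elif isinstance(obj, Wolf) and obj.living:
                    grid[dx + 5, dy + 5, 1] = True
                elif isinstance(obj, GrassPatch) and obj.fully_grown:
                    grid[dx + 5, dy + 5, 2] = True
    grid[5, 5, 3] = True  # mark the observing agent's own cell
    return {
        'grid': grid,
        'energy': np.array([agent.energy], dtype=np.float32),
    }
```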
@@ -0,0 +1,65 @@
from pettingzoo.test import parallel_api_test
from environment import WolfSheep
from ray.tune.registry import register_env
from ray import tune, air
import os
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from agents import Sheep
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
import ray


def env_creator():
    return WolfSheep()


if __name__ == "__main__":
    # Uncomment to debug
    # ray.init(local_mode=True)

    env = WolfSheep()
    parallel_api_test(env, num_cycles=1_000_000)

    # Register the environment under an RLlib name
    register_env('WolfSheepModel', lambda config: ParallelPettingZooEnv(env_creator()))

    # agent_id passed to policy_mapping_fn is a mesa unique_id, not an agent
    # object, so isinstance(agent_id, Sheep) would always be False. Instead,
    # record which ids belong to sheep (this assumes create_initial_agents
    # assigns ids deterministically across resets).
    sheep_ids = {a.unique_id for a in WolfSheep().schedule.agents if isinstance(a, Sheep)}

    # Define the configuration for the PPO algorithm
    config = (
        PPOConfig()
        .environment("WolfSheepModel")
        .framework("torch")
        .multi_agent(
            policies={
                "policy_sheep": PolicySpec(
                    config=PPOConfig.overrides(framework_str="torch")
                ),
                "policy_wolf": PolicySpec(
                    config=PPOConfig.overrides(framework_str="torch")
                )
            },
            policy_mapping_fn=lambda agent_id, *args, **kwargs: "policy_sheep" if agent_id in sheep_ids else "policy_wolf",
            policies_to_train=["policy_sheep", "policy_wolf"],
        )
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "1")), num_cpus_for_local_worker=2, num_cpus_per_learner_worker=2)
        .rl_module(
            rl_module_spec=MultiAgentRLModuleSpec(
                module_specs={
                    "policy_sheep": SingleAgentRLModuleSpec(),
                    "policy_wolf": SingleAgentRLModuleSpec()
                }
            ),
        )
    )

    stop = {
        "training_iteration": 20,
        "episode_reward_mean": 1000,
        "timesteps_total": 1000000,
    }

    results = tune.Tuner(
        "PPO",
        param_space=config.to_dict(),
        run_config=air.RunConfig(stop=stop, verbose=1, checkpoint_config=air.CheckpointConfig(checkpoint_frequency=1)),
    ).fit()
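
Once training finishes, the best checkpoint can be restored for evaluation. A sketch (not part of the commit) using the ray 2.x result and checkpoint APIs:

```python
# Sketch: restore the best checkpoint from the tune results for evaluation.
from ray.rllib.algorithms.algorithm import Algorithm

best_result = results.get_best_result(metric="episode_reward_mean", mode="max")
algo = Algorithm.from_checkpoint(best_result.checkpoint)
```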