Skip to content

Commit

Permalink
Merge pull request #63 from CarperAI/task-rev
Browse files Browse the repository at this point in the history
Refining task api, etc
  • Loading branch information
jsuarez5341 authored Jun 7, 2023
2 parents 3567f38 + 79a66bd commit 7386079
Show file tree
Hide file tree
Showing 24 changed files with 1,713 additions and 1,033 deletions.
112 changes: 32 additions & 80 deletions nmmo/core/env.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools
import random
import copy
from typing import Any, Dict, List, Optional, Union, Tuple
from typing import Any, Dict, List, Callable
from collections import defaultdict
from ordered_set import OrderedSet

import gym
Expand All @@ -16,8 +16,7 @@
from nmmo.entity.entity import Entity
from nmmo.systems.item import Item
from nmmo.task.game_state import GameStateGenerator
from nmmo.task.task_api import Task
from nmmo.task.scenario import default_task
from nmmo.task import task_api
from scripted.baselines import Scripted

class Env(ParallelEnv):
Expand All @@ -41,15 +40,7 @@ def __init__(self,

self._gamestate_generator = GameStateGenerator(self.realm, self.config)
self.game_state = None
# Default task: rewards 1 each turn agent is alive
self.tasks: List[Tuple[Task,float]] = None
self._task_encoding = None
self._task_embedding_size = -1
t = default_task(self.possible_agents)
self.change_task(t,
embedding_size=self._task_embedding_size,
task_encoding=self._task_encoding,
reset=False)
self.tasks = task_api.nmmo_default_task(self.possible_agents)

# pylint: disable=method-cache-max-size-none
@functools.lru_cache(maxsize=None)
Expand Down Expand Up @@ -88,12 +79,6 @@ def box(rows, cols):
if self.config.PROVIDE_ACTION_TARGETS:
obs_space['ActionTargets'] = self.action_space(None)

if self._task_encoding:
obs_space['Task'] = gym.spaces.Box(
low=-2**20, high=2**20,
shape=(self._task_embedding_size,),
dtype=np.float32)

return gym.spaces.Dict(obs_space)

def _init_random(self, seed):
Expand Down Expand Up @@ -131,38 +116,18 @@ def action_space(self, agent):
############################################################################
# Core API

def change_task(self,
new_tasks: List[Union[Tuple[Task, float], Task]],
task_encoding: Optional[Dict[int, np.ndarray]] = None,
embedding_size: int=16,
reset: bool=True,
map_id=None,
seed=None,
options=None):
""" Changes the task given to each agent
Args:
new_task: The task to complete and calculate rewards
task_encoding: A mapping from eid to encoded task
embedding_size: The size of each embedding
reset: Resets the environment
"""
self._tasks = [t if isinstance(t, Tuple) else (t,1) for t in new_tasks]
self._task_encoding = task_encoding
self._task_embedding_size = embedding_size
if reset:
self.reset(map_id=map_id, seed=seed, options=options)

# TODO: This doesn't conform to the PettingZoo API
# pylint: disable=arguments-renamed
def reset(self, map_id=None, seed=None, options=None):
def reset(self, map_id=None, seed=None, options=None,
make_task_fn: Callable=None):
'''OpenAI Gym API reset function
Loads a new game map and returns initial observations
Args:
idx: Map index to load. Selects a random map by default
map_id: Map index to load. Selects a random map by default
seed: random seed to use
make_task_fn: A function to make tasks
Returns:
observations, as documented by _compute_observations()
Expand All @@ -186,16 +151,16 @@ def reset(self, map_id=None, seed=None, options=None):
if isinstance(ent.agent, Scripted):
self.scripted_agents.add(eid)

self.tasks = copy.deepcopy(self._tasks)
self.obs = self._compute_observations()
self._gamestate_generator = GameStateGenerator(self.realm, self.config)

gym_obs = {}
for a, o in self.obs.items():
gym_obs[a] = o.to_gym()
if self._task_encoding:
gym_obs[a]['Task'] = self._encode_goal().get(a,np.zeros(self._task_embedding_size))
return gym_obs
if make_task_fn is not None:
self.tasks = make_task_fn()
else:
for task in self.tasks:
task.reset()

return {a: o.to_gym() for a,o in self.obs.items()}

def step(self, actions: Dict[int, Dict[str, Dict[str, Any]]]):
'''Simulates one game tick or timestep
Expand Down Expand Up @@ -308,11 +273,7 @@ def step(self, actions: Dict[int, Dict[str, Dict[str, Any]]]):

# Store the observations, since actions reference them
self.obs = self._compute_observations()
gym_obs = {}
for a, o in self.obs.items():
gym_obs[a] = o.to_gym()
if self._task_encoding:
gym_obs[a]['Task'] = self._encode_goal()[a]
gym_obs = {a: o.to_gym() for a,o in self.obs.items()}

rewards, infos = self._compute_rewards(self.obs.keys(), dones)

Expand All @@ -321,8 +282,6 @@ def step(self, actions: Dict[int, Dict[str, Dict[str, Any]]]):
def _validate_actions(self, actions: Dict[int, Dict[str, Dict[str, Any]]]):
'''Deserialize action arg values and validate actions
For now, it does a basic validation (e.g., value is not none).
TODO(kywch): add sophisticated validation like use/sell/give on the same item
'''
validated_actions = {}

Expand Down Expand Up @@ -423,9 +382,6 @@ def _compute_observations(self):
inventory, market)
return obs

def _encode_goal(self):
return self._task_encoding

def _compute_rewards(self, agents: List[AgentID], dones: Dict[AgentID, bool]):
'''Computes the reward for the specified agent
Expand All @@ -442,27 +398,23 @@ def _compute_rewards(self, agents: List[AgentID], dones: Dict[AgentID, bool]):
entity identified by ent_id.
'''
# Initialization
self.game_state = self._gamestate_generator.generate(self.realm, self.obs)
infos = {}
for eid in agents:
infos[eid] = {}
infos[eid]['task'] = {}
rewards = {eid: 0 for eid in agents}
infos = {agent_id: {'task': {}} for agent_id in agents}
rewards = defaultdict(int)
agents = set(agents)
reward_cache = {}

# Compute Rewards and infos
for task, weight in self.tasks:
task_rewards, task_infos = task.compute_rewards(self.game_state)
for eid, reward in task_rewards.items():
# Rewards, weighted
rewards[eid] = rewards.get(eid,0) + reward * weight
# Infos
for eid, info in task_infos.items():
if eid in infos:
infos[eid]['task'] = {**infos[eid]['task'], **info}

# Remove rewards for dead agents (?)
for eid in dones:
rewards[eid] = 0
self.game_state = self._gamestate_generator.generate(self.realm, self.obs)
for task in self.tasks:
if task in reward_cache:
task_rewards, task_infos = reward_cache[task]
else:
task_rewards, task_infos = task.compute_rewards(self.game_state)
reward_cache[task] = (task_rewards, task_infos)
for agent_id, reward in task_rewards.items():
if agent_id in agents and agent_id not in dones:
rewards[agent_id] = rewards.get(agent_id,0) + reward
infos[agent_id]['task'][task.name] = task_infos[agent_id] # progress

return rewards, infos

Expand Down
6 changes: 3 additions & 3 deletions nmmo/core/realm.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,6 @@ def reset(self, map_id: int = None):
self.log_helper.reset()
self.event_log.reset()

if self._replay_helper is not None:
self._replay_helper.reset()

self.map.reset(map_id or np.random.randint(self.config.MAP_N) + 1)

# EntityState and ItemState tables must be empty after players/npcs.reset()
Expand All @@ -104,6 +101,9 @@ def reset(self, map_id: int = None):
Item.INSTANCE_ID = 0
self.items = {}

if self._replay_helper is not None:
self._replay_helper.reset()

def packet(self):
"""Client packet"""
return {
Expand Down
1 change: 0 additions & 1 deletion nmmo/lib/spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,5 +135,4 @@ def get_team_spawn_positions(config, num_teams):
idx = int(len(side)*(i+1)/(teams_per_sides + 1))
team_spawn_positions.append(side[idx])

np.random.shuffle(team_spawn_positions)
return team_spawn_positions
14 changes: 13 additions & 1 deletion nmmo/lib/team_helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Dict, List


class TeamHelper():
def __init__(self, teams: Dict[int, List[int]]):
self.teams = teams
Expand All @@ -23,3 +22,16 @@ def agent_id(self, team_id: int, position: int) -> int:

def is_agent_in_team(self, agent_id:int , team_id: int) -> bool:
return agent_id in self.teams[team_id]

def get_target_agent(self, team_id: int, target: str):
if target == 'left_team':
return self.teams[(team_id+1) % self.num_teams]
if target == 'left_team_leader':
return self.teams[(team_id+1) % self.num_teams][0]
if target == 'right_team':
return self.teams[(team_id-1) % self.num_teams]
if target == 'right_team_leader':
return self.teams[(team_id-1) % self.num_teams][0]
if target == 'my_team_leader':
return self.teams[team_id][0]
return None
1 change: 1 addition & 0 deletions nmmo/render/replay_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def reset(self):
self.packets = []
self.map = None
self._i = 0
self.update() # to capture the initial packet

def __len__(self):
return len(self.packets)
Expand Down
9 changes: 6 additions & 3 deletions nmmo/systems/skill.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,13 +265,13 @@ def update(self):
if not config.RESOURCE_SYSTEM_ENABLED:
return

if config.IMMORTAL:
return

depletion = config.RESOURCE_DEPLETION_RATE
water = self.entity.resources.water
water.decrement(depletion)

if self.config.IMMORTAL:
return

if not self.harvest_adjacent(material.Water, deplete=False):
return

Expand All @@ -288,6 +288,9 @@ def update(self):
if not config.RESOURCE_SYSTEM_ENABLED:
return

if config.IMMORTAL:
return

depletion = config.RESOURCE_DEPLETION_RATE
food = self.entity.resources.food
food.decrement(depletion)
Expand Down
3 changes: 1 addition & 2 deletions nmmo/task/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .game_state import *
from .predicate_api import *
from .task_api import *
from .scenario import *
from .team_helper import *
Loading

0 comments on commit 7386079

Please sign in to comment.