Skip to content

Commit

Permalink
Merge pull request #27 from CarperAI/daveey-git-pr-787-5807
Browse files Browse the repository at this point in the history
Fix dones / rewards
  • Loading branch information
daveey authored Feb 28, 2023
2 parents acdb38c + 2a35f26 commit b136538
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions nmmo/core/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(self,
self.obs = None

self.possible_agents = list(range(1, config.PLAYER_N + 1))
self._dead_agents = set()
self.scripted_agents = OrderedSet()

# pylint: disable=method-cache-max-size-none
Expand Down Expand Up @@ -130,6 +131,7 @@ def reset(self, map_id=None, seed=None, options=None):

self._init_random(seed)
self.realm.reset(map_id)
self._dead_agents = set()

# check if there are scripted agents
for eid, ent in self.realm.players.items():
Expand Down Expand Up @@ -246,13 +248,18 @@ def step(self, actions: Dict[int, Dict[str, Dict[str, Any]]]):

# Execute actions
self.realm.step(actions)
dones = {eid: eid not in self.realm.players for eid in self.possible_agents}

dones = {}
for eid in self.possible_agents:
if eid not in self.realm.players and eid not in self._dead_agents:
self._dead_agents.add(eid)
dones[eid] = True

# Store the observations, since actions reference them
self.obs = self._compute_observations()
gym_obs = {a: o.to_gym() for a,o in self.obs.items()}

rewards, infos = self._compute_rewards(self.obs.keys())
rewards, infos = self._compute_rewards(self.obs.keys(), dones)

return gym_obs, rewards, dones, infos

Expand Down Expand Up @@ -394,7 +401,7 @@ def _compute_observations(self):

return obs

def _compute_rewards(self, agents: List[AgentID] = None):
def _compute_rewards(self, agents: List[AgentID], dones: Dict[AgentID, bool]):
'''Computes the reward for the specified agent
Override this method to create custom reward functions. You have full
Expand All @@ -410,15 +417,12 @@ def _compute_rewards(self, agents: List[AgentID] = None):
entity identified by ent_id.
'''
infos = {}
rewards = {}
rewards = { eid: -1 for eid in dones }

for agent_id in agents:
infos[agent_id] = {}
agent = self.realm.players.get(agent_id)

if agent is None:
rewards[agent_id] = -1
continue
assert agent is not None, f'Agent {agent_id} not found'

infos[agent_id] = {'population': agent.population}

Expand Down

0 comments on commit b136538

Please sign in to comment.