Skip to content

Commit

Permalink
make dummy obs at reset and use it
Browse files Browse the repository at this point in the history
  • Loading branch information
kywch committed Jun 14, 2023
1 parent 222c795 commit 7a35c25
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions nmmo/core/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import random
from typing import Any, Dict, List, Callable
from collections import defaultdict
from copy import copy
from ordered_set import OrderedSet

import gym
Expand Down Expand Up @@ -33,6 +34,7 @@ def __init__(self,
self.config = config
self.realm = realm.Realm(config)
self.obs = None
self._dummy_obs = None

self.possible_agents = list(range(1, config.PLAYER_N + 1))
self._dead_agents = set()
Expand Down Expand Up @@ -157,6 +159,7 @@ def reset(self, map_id=None, seed=None, options=None,
if isinstance(ent.agent, Scripted):
self.scripted_agents.add(eid)

self._dummy_obs = self._make_dummy_obs()
self.obs = self._compute_observations()
self._gamestate_generator = GameStateGenerator(self.realm, self.config)

Expand Down Expand Up @@ -360,22 +363,26 @@ def _compute_scripted_agent_actions(self, actions: Dict[int, Dict[str, Dict[str,

return actions

def _compute_observations(self):
'''Create an Observation object for each agent in self.agents'''
obs = {}
market = Item.Query.for_sale(self.realm.datastore) # the same for all agents

# dummy obs
def _make_dummy_obs(self):
dummy_tiles = np.zeros((0, len(Tile.State.attr_name_to_col)))
dummy_entities = np.zeros((0, len(Entity.State.attr_name_to_col)))
dummy_inventory = np.zeros((0, len(Item.State.attr_name_to_col)))
dummy_market = np.zeros((0, len(Item.State.attr_name_to_col)))
return Observation(self.config, self.realm.tick, 0,
dummy_tiles, dummy_entities, dummy_inventory, dummy_market)

def _compute_observations(self):
'''Create an Observation object for each agent in self.agents'''
obs = {}
market = Item.Query.for_sale(self.realm.datastore) # the same for all agents

for agent_id in self.agents:
if agent_id not in self.realm.players:
# return dummy obs for the agents in dead_this_tick
obs[agent_id] = Observation(self.config, self.realm.tick, agent_id,
dummy_tiles, dummy_entities, dummy_inventory, dummy_market)
dummy_obs = copy(self._dummy_obs)
dummy_obs.current_tick = self.realm.tick
dummy_obs.agent_id = agent_id
obs[agent_id] = dummy_obs
else:
agent = self.realm.players.get(agent_id)
agent_r = agent.row.val
Expand Down

0 comments on commit 7a35c25

Please sign in to comment.