diff --git a/nmmo/core/env.py b/nmmo/core/env.py index 4c65708d..f678a7fe 100644 --- a/nmmo/core/env.py +++ b/nmmo/core/env.py @@ -2,6 +2,7 @@ import random from typing import Any, Dict, List, Callable from collections import defaultdict +from copy import copy from ordered_set import OrderedSet import gym @@ -33,6 +34,7 @@ def __init__(self, self.config = config self.realm = realm.Realm(config) self.obs = None + self._dummy_obs = None self.possible_agents = list(range(1, config.PLAYER_N + 1)) self._dead_agents = set() @@ -157,6 +159,7 @@ def reset(self, map_id=None, seed=None, options=None, if isinstance(ent.agent, Scripted): self.scripted_agents.add(eid) + self._dummy_obs = self._make_dummy_obs() self.obs = self._compute_observations() self._gamestate_generator = GameStateGenerator(self.realm, self.config) @@ -360,22 +363,26 @@ def _compute_scripted_agent_actions(self, actions: Dict[int, Dict[str, Dict[str, return actions - def _compute_observations(self): - '''Create an Observation object for each agent in self.agents''' - obs = {} - market = Item.Query.for_sale(self.realm.datastore) # the same for all agents - - # dummy obs + def _make_dummy_obs(self): dummy_tiles = np.zeros((0, len(Tile.State.attr_name_to_col))) dummy_entities = np.zeros((0, len(Entity.State.attr_name_to_col))) dummy_inventory = np.zeros((0, len(Item.State.attr_name_to_col))) dummy_market = np.zeros((0, len(Item.State.attr_name_to_col))) + return Observation(self.config, self.realm.tick, 0, + dummy_tiles, dummy_entities, dummy_inventory, dummy_market) + + def _compute_observations(self): + '''Create an Observation object for each agent in self.agents''' + obs = {} + market = Item.Query.for_sale(self.realm.datastore) # the same for all agents for agent_id in self.agents: if agent_id not in self.realm.players: # return dummy obs for the agents in dead_this_tick - obs[agent_id] = Observation(self.config, self.realm.tick, agent_id, - dummy_tiles, dummy_entities, dummy_inventory, dummy_market) + dummy_obs = copy(self._dummy_obs) + dummy_obs.current_tick = self.realm.tick + dummy_obs.agent_id = agent_id + obs[agent_id] = dummy_obs else: agent = self.realm.players.get(agent_id) agent_r = agent.row.val