diff --git a/mo_gymnasium/envs/fishwood/fishwood.py b/mo_gymnasium/envs/fishwood/fishwood.py index fad02d77..7aa1242b 100644 --- a/mo_gymnasium/envs/fishwood/fishwood.py +++ b/mo_gymnasium/envs/fishwood/fishwood.py @@ -42,8 +42,8 @@ class FishWood(gym.Env, EzPickle): """ metadata = {"render_modes": ["human"]} - FISH = 0 - WOOD = 1 + FISH = np.array([0], dtype=np.int32) + WOOD = np.array([1], dtype=np.int32) MAX_TS = 200 def __init__(self, render_mode: Optional[str] = None, fishproba=0.1, woodproba=0.9): @@ -55,17 +55,17 @@ def __init__(self, render_mode: Optional[str] = None, fishproba=0.1, woodproba=0 self.action_space = spaces.Discrete(2) # 2 actions, go fish and go wood # 2 states, fishing and in the woods - self.observation_space = spaces.Discrete(2) + self.observation_space = spaces.Box(low=0, high=1, shape=(1,), dtype=np.int32) # 2 objectives, amount of fish and amount of wood self.reward_space = spaces.Box(low=np.array([0, 0]), high=np.array([1.0, 1.0]), dtype=np.float32) self.reward_dim = 2 - self._state = self.WOOD + self._state = self.WOOD.copy() def reset(self, seed=None, **kwargs): super().reset(seed=seed) - self._state = self.WOOD + self._state = self.WOOD.copy() self._timestep = 0 if self.render_mode == "human": self.render() @@ -89,7 +89,7 @@ def step(self, action): rewards[self.FISH] = 1.0 # Execute the action - self._state = action + self._state = np.array([action], dtype=np.int32) self._timestep += 1 if self.render_mode == "human":