breakable-bottles observation space correction (#93)
Signed-off-by: Scott Johnson <[email protected]>
Co-authored-by: Scott Johnson <[email protected]>
scott-j-johnson and Scott Johnson authored Apr 15, 2024
1 parent a30e5c0 commit 503dbf0
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions mo_gymnasium/envs/breakable_bottles/breakable_bottles.py
@@ -25,9 +25,13 @@ class BreakableBottles(Env, EzPickle):
  The observation space is a dictionary with 4 keys:
  - location: the current location of the agent
  - bottles_carrying: the number of bottles the agent is currently carrying (0, 1 or 2)
- - bottles_delivered: the number of bottles the agent has delivered (0 or 1)
+ - bottles_delivered: the number of bottles the agent has delivered (0, 1 or 2)
  - bottles_dropped: for each location, a boolean flag indicating if that location currently contains a bottle
+ Note that this observation space is different from that listed in the paper above. In the paper, bottles_delivered's possible values are listed as (0 or 1),
+ rather than (0, 1 or 2). This is because the paper did not take the terminal state, in which 2 bottles have been delivered, into account when calculating
+ the observation space. As such, the observation space of this implementation is larger than specified in the paper, having 360 possible states instead of 240.
  ## Reward Space
  The reward space has 3 dimensions:
  - time penalty: -1 for each time step
@@ -96,11 +100,11 @@ def __init__(
  {
  "location": Discrete(self.size),
  "bottles_carrying": Discrete(3),
- "bottles_delivered": Discrete(2),
+ "bottles_delivered": Discrete(3),
  "bottles_dropped": MultiBinary(self.size - 2),
  }
  )
- self.num_observations = 240
+ self.num_observations = 360

  self.action_space = Discrete(3) # LEFT, RIGHT, PICKUP
  self.num_actions = 3
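
The 240 and 360 figures follow from multiplying the sizes of the four observation components. A minimal sketch of that arithmetic, assuming `size` is 5 (the value consistent with both figures; it is not shown in this diff) and using a hypothetical helper `count_observations` that is not part of the repository:

```python
from math import prod


def count_observations(size: int, delivered_values: int) -> int:
    """Count distinct observations for a corridor with `size` locations.

    Hypothetical helper for illustration only. The factors mirror the
    observation space in the diff: location, bottles_carrying (3 values),
    bottles_delivered, and one binary flag per intermediate location.
    """
    dims = [size, 3, delivered_values, *([2] * (size - 2))]
    return prod(dims)


SIZE = 5  # assumption: inferred from the 240 / 360 figures, not shown in the diff

old_count = count_observations(SIZE, delivered_values=2)  # Discrete(2)
new_count = count_observations(SIZE, delivered_values=3)  # Discrete(3)
print(old_count, new_count)
```

Under that assumption the two counts come out to 240 and 360, matching the old and new values of `num_observations`.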
@@ -220,7 +224,7 @@ def get_obs_idx(self, obs):
  *[[bd > 0] for bd in obs["bottles_dropped"]],
  ]
  )
- return np.ravel_multi_index(multi_index, tuple([self.size, 3, 2, *([2] * (self.size - 2))]))
+ return np.ravel_multi_index(multi_index, tuple([self.size, 3, 3, *([2] * (self.size - 2))]))

  def _get_obs(self):
  return {
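
The third factor in the shape matters because `np.ravel_multi_index`, with its default arguments, rejects any coordinate that falls outside its dimension. The sketch below uses a pure-Python stand-in, `ravel_index` (a hypothetical helper, not part of the repository), that mirrors that row-major, bounds-checked behaviour. `SIZE = 5` and the sample observation are assumptions chosen for illustration.

```python
def ravel_index(multi_index, dims):
    """Row-major flat index with bounds checking.

    Hypothetical stand-in for np.ravel_multi_index with default arguments
    (C order, out-of-range coordinates raise ValueError).
    """
    if len(multi_index) != len(dims):
        raise ValueError("multi_index and dims must have the same length")
    flat = 0
    for i, d in zip(multi_index, dims):
        if not 0 <= i < d:
            raise ValueError(f"index {i} out of bounds for dimension of size {d}")
        flat = flat * d + i
    return flat


SIZE = 5  # assumption: not shown in the diff
old_dims = (SIZE, 3, 2, *([2] * (SIZE - 2)))  # shape before the commit
new_dims = (SIZE, 3, 3, *([2] * (SIZE - 2)))  # shape after the commit

# Illustrative observation: (location, carrying, delivered, dropped flags...).
# Only delivered == 2 matters here; the other values are arbitrary.
two_delivered = (SIZE - 1, 0, 2, 0, 0, 0)

try:
    ravel_index(two_delivered, old_dims)
    old_shape_ok = True
except ValueError:
    old_shape_ok = False  # delivered == 2 does not fit a dimension of size 2

new_idx = ravel_index(two_delivered, new_dims)
largest_idx = ravel_index(tuple(d - 1 for d in new_dims), new_dims)
print(old_shape_ok, new_idx, largest_idx)
```

With the old shape, an observation with two delivered bottles cannot be indexed at all. With the new shape, every index stays below the updated `num_observations`.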