-
Notifications
You must be signed in to change notification settings - Fork 329
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BugFix] PettingZoo dict action spaces #2692
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2692
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM thanks!
@matteobettini no need for test here? |
We can add it, but pettingzoo itself does not provide an env with this action space type. What I am using to test is this env. I can include it, but it seems like a lot of lines to test this simple thing so didn't want to clutter wdyt? class Environment(pettingzoo.ParallelEnv):
"""
Sample environment where two agents have two rocks in front of them.
Each agent predicts a probability distribution of which rock to hit and what strength to hit it with.
First agent to break a rock could win or something like that but doesn't really matter for demo purposes.
"""
agent_strengths = {
"agent_0": 100,
"agent_1": 80,
}
rock_hps = [700, 400]
def __init__(self):
pettingzoo.ParallelEnv.__init__(self)
agents = [
"agent_0",
"agent_1",
]
self.agents = agents
self.possible_agents = agents
def reset(self, **kwargs):
observations = {
"agent_0": {"rocks": [1.0, 1.0]},
"agent_1": {"rocks": [1.0, 1.0]},
}
return observations, {"agent_0": {}, "agent_1": {}}
def action_space(self, agent: AgentID) -> gymnasium.spaces.Space:
return spaces.Dict(
[
("target", spaces.Box(0, 1, [2])), # probability of hitting each rock
(
"strength",
spaces.Box(0, 1, [2]),
), # percentage strength to hit it with
(
"healing",
spaces.Box(0, 100, [2]),
), # each agent can heal rock for an absolute value < 100
]
)
def observation_space(self, agent: AgentID) -> gymnasium.spaces.Space:
return spaces.Dict(
[
("rocks", spaces.Box(0, 1, [2])), # rock hp as a percentage
]
)
def step(
self, actions: dict[AgentID, ActionType]
) -> tuple[
dict[AgentID, ObsType],
dict[AgentID, float],
dict[AgentID, bool],
dict[AgentID, bool],
dict[AgentID, dict],
]:
damage_dealt = {}
for agent, agent_actions in actions.items():
target = agent_actions["target"]
target = np.argmax(target)
damage = agent_actions["strength"] * self.agent_strengths[agent]
damage_dealt[agent] = (target, damage)
observations = {
"agent_0": {
"rocks": [1.0, 1.0],
},
"agent_1": {
"rocks": [1.0, 1.0],
},
}
rewards = {
"agent_0": 0.0,
"agent_1": 0.0,
}
terminations = {
"agent_0": False,
"agent_1": False,
}
truncations = {
"agent_0": False,
"agent_1": False,
}
info = {}
return observations, rewards, terminations, truncations, info |
If that solves a problem but there's no test we should at least have a comment in the code that explains what's going on (as a soft safekeeping) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
fixes #2680