Skip to content

Commit

Permalink
Make new gymnasium-based 2.2 candidate working with Ray RLlib
Browse files Browse the repository at this point in the history
  • Loading branch information
alexpalms committed Sep 8, 2023
1 parent 45ab787 commit 33fc1e8
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 11 deletions.
18 changes: 8 additions & 10 deletions diambra/arena/ray_rllib/make_ray_env.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import os
import diambra.arena
import logging
import gym
import gymnasium as gym
from ray.rllib.env.env_context import EnvContext
from copy import deepcopy
import pickle

class DiambraArena(gym.Env):

def __init__(self, config: EnvContext):

self.logger = logging.getLogger(__name__)

# If to load environment spaces from a file
Expand All @@ -31,7 +29,6 @@ def __init__(self, config: EnvContext):
self.env_spaces_file_name = config["env_spaces_file_name"]

if self.load_spaces_from_file is False:

if "is_rollout" not in config.keys():
message = "Environment initialized without a preprocessed config file."
message += " Make sure to call \"preprocess_ray_config\" before initializing Ray RL Algorithms."
Expand All @@ -40,7 +37,7 @@ def __init__(self, config: EnvContext):
self.game_id = config["game_id"]
self.settings = config["settings"] if "settings" in config.keys() else {}
self.wrappers_settings = config["wrappers_settings"] if "wrappers_settings" in config.keys() else {}
self.seed = config["seed"] if "seed" in config.keys() else 0
self.render_mode = config["render_mode"] if "render_mode" in config.keys() else None

num_rollout_workers = config["num_workers"]
num_eval_workers = config["evaluation_num_workers"]
Expand Down Expand Up @@ -68,8 +65,7 @@ def __init__(self, config: EnvContext):

self.logger.debug("Rank: {}".format(self.rank))

self.env = diambra.arena.make(self.game_id, self.settings, self.wrappers_settings,
seed=self.seed + self.rank, rank=self.rank)
self.env = diambra.arena.make(self.game_id, self.settings, self.wrappers_settings, render_mode=self.render_mode, rank=self.rank)

env_spaces_dict = {}
env_spaces_dict["action_space"] = self.env.action_space
Expand All @@ -93,8 +89,11 @@ def __init__(self, config: EnvContext):
self.action_space = env_spaces_dict["action_space"]
self.observation_space = env_spaces_dict["observation_space"]

def reset(self):
return self.env.reset()
def reset(self, seed=None, options=None):
if self.load_spaces_from_file is True:
return self.observation_space.sample(), {}
else:
return self.env.reset(seed=seed, options=options)

def step(self, action):
return self.env.step(action)
Expand All @@ -103,7 +102,6 @@ def render(self):
return self.env.render()

def preprocess_ray_config(config):

logger = logging.getLogger(__name__)

num_envs_required = 0
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
'tests': ['pytest', 'pytest-mock', 'testresources'],
'stable-baselines': ['stable-baselines==2.10.2', 'gym<=0.21.0', "protobuf==3.20.1", "pyyaml"],
'stable-baselines3': ['stable-baselines3[extra]==2.1.0', "pyyaml"],
'ray-rllib': ['ray[rllib]==2.0.0', 'tensorflow<=2.10.0', 'torch<=1.12.1', "pyyaml"],
'ray-rllib': ['ray[rllib]==2.6.3', 'tensorflow', 'torch', "pyyaml"],
}

# NOTE Package data is inside MANIFEST.In
Expand Down

0 comments on commit 33fc1e8

Please sign in to comment.