Skip to content

Commit

Permalink
pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
reginald-mclean committed Aug 30, 2024
1 parent 2b57f0e commit 7d37933
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 26 deletions.
34 changes: 9 additions & 25 deletions metaworld/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,9 +337,11 @@ def init_each_env(
env_cls: type[SawyerXYZEnv], name: str, seed: int | None
) -> gym.Env:
env = env_cls()
if seed:
env.seed(seed)
env = gym.wrappers.TimeLimit(env, max_episode_steps or env.max_path_length)
if terminate_on_success:
env = AutoTerminateOnSuccessWrapper(env)
env = AutoTerminateOnSuccessWrapper(env)
env.toggle_terminate_on_success(terminate_on_success)
env = gym.wrappers.RecordEpisodeStatistics(env)
if use_one_hot:
assert env_id is not None, "Need to pass env_id through constructor"
Expand All @@ -349,29 +351,9 @@ def init_each_env(
env = RandomTaskSelectWrapper(env, tasks, seed=seed)
return env

if "MT1-" in name:
name = name.replace("MT1-", "")
benchmark = MT1(name, seed=seed)
return init_each_env(
env_cls=benchmark.train_classes[name], name=name, seed=seed
)
elif "ML1-" in name:
benchmark = ML1(
name.replace("ML1-train-" if "train" in name else "ML1-test-", ""),
seed=seed,
) # type: ignore
if "train" in name:
return init_each_env(
env_cls=benchmark.train_classes[name.replace("ML1-train-", "")],
name=name + "-train",
seed=seed,
) # type: ignore
elif "test" in name:
return init_each_env(
env_cls=benchmark.test_classes[name.replace("ML1-test-", "")],
name=name + "-test",
seed=seed,
)
name = name.replace("MT1-", "")
benchmark = MT1(name, seed=seed)
return init_each_env(env_cls=benchmark.train_classes[name], name=name, seed=seed)


make_single_mt = partial(_make_single_env, terminate_on_success=False)
Expand Down Expand Up @@ -405,6 +387,8 @@ def _make_single_ml(

def make_env(env_cls: type[SawyerXYZEnv], tasks: list) -> gym.Env:
env = env_cls()
if seed:
env.seed(seed)
env = gym.wrappers.TimeLimit(env, max_episode_steps or env.max_path_length)
env = AutoTerminateOnSuccessWrapper(env)
env.toggle_terminate_on_success(terminate_on_success)
Expand Down
2 changes: 1 addition & 1 deletion metaworld/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _set_random_task(self):
def __init__(
self,
env: Env,
tasks: list[Task],
tasks: List[Task],
sample_tasks_on_reset: bool = True,
seed: int | None = None,
):
Expand Down

0 comments on commit 7d37933

Please sign in to comment.