diff --git a/metaworld/__init__.py b/metaworld/__init__.py index 19b9729c..a3255ebe 100644 --- a/metaworld/__init__.py +++ b/metaworld/__init__.py @@ -6,7 +6,7 @@ import pickle from collections import OrderedDict from functools import partial -from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import Any, Literal import gymnasium as gym # type: ignore import numpy as np @@ -14,9 +14,7 @@ # noqa: D104 from gymnasium.envs.registration import register -from numpy.typing import NDArray -import metaworld # type: ignore import metaworld.env_dict as _env_dict from metaworld.env_dict import ( ALL_V3_ENVIRONMENTS, @@ -324,485 +322,373 @@ def __init__(self, seed=None): ) -def _make_single_env( - name: str, +class CustomML(Benchmark): + """ + A custom meta RL benchmark. + Provide the desired train and test env names during initialisation. + """ + + def __init__(self, train_envs: list[str], test_envs: list[str], seed=None): + if len(set(train_envs).intersection(set(test_envs))) != 0: + raise ValueError("The test tasks cannot contain any of the train tasks.") + + self._train_classes = _env_dict._get_env_dict(train_envs) + train_kwargs = _env_dict._get_args_kwargs( + ALL_V3_ENVIRONMENTS, self._train_classes + ) + + self._test_classes = _env_dict._get_env_dict(test_envs) + test_kwargs = _env_dict._get_args_kwargs( + ALL_V3_ENVIRONMENTS, self._test_classes + ) + + self._train_tasks = _make_tasks( + self._train_classes, train_kwargs, _ML_OVERRIDE, seed=seed + ) + self._test_tasks = _make_tasks( + self._test_classes, test_kwargs, _ML_OVERRIDE, seed=seed + ) + + +def _init_each_env( + env_cls: type[SawyerXYZEnv], + tasks: list[Task], seed: int | None = None, max_episode_steps: int | None = None, + terminate_on_success: bool = False, use_one_hot: bool = False, env_id: int | None = None, num_tasks: int | None = None, - terminate_on_success: bool = False, + task_select: Literal["random", "pseudorandom"] = "random", ) -> gym.Env: - def init_each_env( - env_cls: type[SawyerXYZEnv], name: str, seed: int | None - ) -> gym.Env: - env = env_cls() - if seed: - env.seed(seed) - env = gym.wrappers.TimeLimit(env, max_episode_steps or env.max_path_length) - env = AutoTerminateOnSuccessWrapper(env) - env.toggle_terminate_on_success(terminate_on_success) - env = gym.wrappers.RecordEpisodeStatistics(env) - if use_one_hot: - assert env_id is not None, "Need to pass env_id through constructor" - assert num_tasks is not None, "Need to pass num_tasks through constructor" - env = OneHotWrapper(env, env_id, num_tasks) - tasks = [task for task in benchmark.train_tasks if task.env_name in name] - env = RandomTaskSelectWrapper(env, tasks, seed=seed) - return env - - name = name.replace("MT1-", "") - benchmark = MT1(name, seed=seed) - return init_each_env(env_cls=benchmark.train_classes[name], name=name, seed=seed) - - -make_single_mt = partial(_make_single_env, terminate_on_success=False) - - -def _make_single_ml( + env: gym.Env = env_cls() + if seed is not None: + env.seed(seed) # type: ignore + env = gym.wrappers.TimeLimit(env, max_episode_steps or env.max_path_length) # type: ignore + env = AutoTerminateOnSuccessWrapper(env) + env.toggle_terminate_on_success(terminate_on_success) + env = gym.wrappers.RecordEpisodeStatistics(env) + if use_one_hot: + assert env_id is not None, "Need to pass env_id through constructor" + assert num_tasks is not None, "Need to pass num_tasks through constructor" + env = OneHotWrapper(env, env_id, num_tasks) + if task_select != "random": + env = 
PseudoRandomTaskSelectWrapper(env, tasks) + else: + env = RandomTaskSelectWrapper(env, tasks) + return env + + +def make_mt_envs( name: str, - seed: int, - tasks_per_env: int, - env_num: int, + seed: int | None = None, max_episode_steps: int | None = None, - split: str = "train", + use_one_hot: bool = False, + env_id: int | None = None, + num_tasks: int | None = None, terminate_on_success: bool = False, - task_select: str = "random", + vector_strategy: Literal["sync", "async"] = "sync", + task_select: Literal["random", "pseudorandom"] = "random", +) -> gym.Env | gym.vector.VectorEnv: + benchmark: Benchmark + if name in ALL_V3_ENVIRONMENTS.keys(): + benchmark = MT1(name, seed=seed) + tasks = [task for task in benchmark.train_tasks] + return _init_each_env( + env_cls=benchmark.train_classes[name], + tasks=tasks, + seed=seed, + max_episode_steps=max_episode_steps, + use_one_hot=use_one_hot, + env_id=env_id, + num_tasks=num_tasks or 1, + terminate_on_success=terminate_on_success, + ) + elif name == "MT10" or name == "MT50": + benchmark = globals()[name](seed=seed) + vectorizer: type[gym.vector.VectorEnv] = getattr( + gym.vector, f"{vector_strategy.capitalize()}VectorEnv" + ) + default_num_tasks = 10 if name == "MT10" else 50 + return vectorizer( # type: ignore + [ + partial( + _init_each_env, + env_cls=env_cls, + tasks=[ + task for task in benchmark.train_tasks if task.env_name == name + ], + seed=seed, + max_episode_steps=max_episode_steps, + use_one_hot=use_one_hot, + env_id=env_id, + num_tasks=num_tasks or default_num_tasks, + terminate_on_success=terminate_on_success, + task_select=task_select, + ) + for env_id, (name, env_cls) in enumerate( + benchmark.train_classes.items() + ) + ] + ) + else: + raise ValueError( + "Invalid MT env name. Must either be a valid Metaworld task name (e.g. 'reach-v3'), 'MT10' or 'MT50'." 
+ ) + + +def _make_ml_envs_inner( + benchmark: Benchmark, + meta_batch_size: int, + seed: int | None = None, total_tasks_per_cls: int | None = None, + max_episode_steps: int | None = None, + split: Literal["train", "test"] = "train", + terminate_on_success: bool = False, + task_select: Literal["random", "pseudorandom"] = "pseudorandom", + vector_strategy: Literal["sync", "async"] = "sync", ): - benchmark = ML1( - name.replace("ML1-train-" if "train" in name else "ML1-test-", ""), + all_classes = ( + benchmark.train_classes if split == "train" else benchmark.test_classes + ) + all_tasks = benchmark.train_tasks if split == "train" else benchmark.test_tasks + assert ( + meta_batch_size % len(all_classes) == 0 + ), "meta_batch_size must be divisible by envs_per_task" + tasks_per_env = meta_batch_size // len(all_classes) + + env_tuples = [] + for env_name, env_cls in all_classes.items(): + tasks = [task for task in all_tasks if task.env_name == env_name] + if total_tasks_per_cls is not None: + tasks = tasks[:total_tasks_per_cls] + subenv_tasks = [tasks[i::tasks_per_env] for i in range(0, tasks_per_env)] + for tasks_for_subenv in subenv_tasks: + assert ( + len(tasks_for_subenv) == len(tasks) // tasks_per_env + ), f"Invalid division of subtasks, expected {len(tasks) // tasks_per_env} got {len(tasks_for_subenv)}" + env_tuples.append((env_cls, tasks_for_subenv)) + + vectorizer: type[gym.vector.VectorEnv] = getattr( + gym.vector, f"{vector_strategy.capitalize()}VectorEnv" + ) + return vectorizer( # type: ignore + [ + partial( + _init_each_env, + env_cls=env_cls, + tasks=tasks, + seed=seed, + max_episode_steps=max_episode_steps, + terminate_on_success=terminate_on_success, + task_select=task_select, + ) + for env_cls, tasks in env_tuples + ] + ) + + +def make_ml_envs( + name: str, + seed: int | None = None, + meta_batch_size: int = 20, + total_tasks_per_cls: int | None = None, + max_episode_steps: int | None = None, + split: Literal["train", "test"] = "train", + terminate_on_success: bool = False, + task_select: Literal["random", "pseudorandom"] = "pseudorandom", + vector_strategy: Literal["sync", "async"] = "sync", +) -> gym.vector.VectorEnv: + benchmark: Benchmark + if name in ALL_V3_ENVIRONMENTS.keys(): + benchmark = ML1(name, seed=seed) + elif name == "ML10" or name == "ML45": + benchmark = globals()[name](seed=seed) + else: + raise ValueError( + "Invalid ML env name. Must either be a valid Metaworld task name (e.g. 'reach-v3'), 'ML10' or 'ML45'." 
+ ) + return _make_ml_envs_inner( + benchmark, + meta_batch_size=meta_batch_size, seed=seed, - ) # type: ignore - cls = ( - benchmark.train_classes[name.replace("ML1-train-", "")] - if split == "train" - else benchmark.test_classes[name.replace("ML1-test-", "")] + total_tasks_per_cls=total_tasks_per_cls, + max_episode_steps=max_episode_steps, + split=split, + terminate_on_success=terminate_on_success, + task_select=task_select, + vector_strategy=vector_strategy, ) - tasks = benchmark.train_tasks if split == "train" else benchmark.test_tasks - - if total_tasks_per_cls is not None: - tasks = tasks[:total_tasks_per_cls] - tasks = [tasks[i::tasks_per_env] for i in range(0, tasks_per_env)][env_num] - - def make_env(env_cls: type[SawyerXYZEnv], tasks: list) -> gym.Env: - env = env_cls() - if seed: - env.seed(seed) - env = gym.wrappers.TimeLimit(env, max_episode_steps or env.max_path_length) - env = AutoTerminateOnSuccessWrapper(env) - env.toggle_terminate_on_success(terminate_on_success) - env = gym.wrappers.RecordEpisodeStatistics(env) - if task_select != "random": - env = PseudoRandomTaskSelectWrapper(env, tasks) - else: - env = RandomTaskSelectWrapper(env, tasks) - return env - - return make_env(cls, tasks) - - -make_single_ml_train = partial( - _make_single_ml, + + +make_ml_envs_train = partial( + make_ml_envs, terminate_on_success=False, task_select="pseudorandom", split="train", ) -make_single_ml_test = partial( - _make_single_ml, terminate_on_success=True, task_select="pseudorandom", split="test" +make_ml_envs_test = partial( + make_ml_envs, terminate_on_success=True, task_select="pseudorandom", split="test" ) -def register_mw_envs(): - for name in ALL_V3_ENVIRONMENTS: - kwargs = {"name": "MT1-" + name} - register( - id=f"Meta-World/{name}", - entry_point="metaworld:make_single_mt", - kwargs=kwargs, +def register_mw_envs() -> None: + def _mt_bench_vector_entry_point( + mt_bench: str, + vector_strategy: Literal["sync", "async"], + seed=None, + use_one_hot=False, + num_envs=None, + *args, + **lamb_kwargs, + ): + return make_mt_envs( # type: ignore + mt_bench, + seed=seed, + use_one_hot=use_one_hot, + vector_strategy=vector_strategy, # type: ignore + *args, + **lamb_kwargs, ) - kwargs = {"name": "ML1-train-" + name} - register( - id=f"Meta-World/ML1-train-{name}", - entry_point="metaworld:make_single_ml_train", - kwargs=kwargs, + + def _ml_bench_vector_entry_point( + ml_bench: str, + split: str, + vector_strategy: Literal["sync", "async"], + seed: int | None = None, + meta_batch_size: int = 20, + num_envs=None, + *args, + **lamb_kwargs, + ): + env_generator = make_ml_envs_train if split == "train" else make_ml_envs_test + return env_generator( + ml_bench, + seed=seed, + meta_batch_size=meta_batch_size, + vector_strategy=vector_strategy, + *args, + **lamb_kwargs, ) - kwargs = {"name": "ML1-test-" + name} + + for name in ALL_V3_ENVIRONMENTS.keys(): + kwargs = {"name": name} register( - id=f"Meta-World/ML1-test-{name}", - entry_point="metaworld:make_single_ml_test", + id=f"Meta-World/{name}", + entry_point="metaworld:make_mt_envs", kwargs=kwargs, ) + for vector_strategy in ["sync", "async"]: + for split in ["train", "test"]: + register( + id=f"Meta-World/ML1-{split}-{name}-{vector_strategy}", + vector_entry_point=partial( + _ml_bench_vector_entry_point, name, split, vector_strategy + ), + kwargs={}, + ) for name_hid in ALL_V3_ENVIRONMENTS_GOAL_HIDDEN: - kwargs = {} register( id=f"Meta-World/{name_hid}", - entry_point=lambda seed: ALL_V3_ENVIRONMENTS_GOAL_HIDDEN[name_hid]( + entry_point=lambda 
seed: ALL_V3_ENVIRONMENTS_GOAL_HIDDEN[name_hid]( # type: ignore seed=seed ), - kwargs=kwargs, + kwargs={}, ) for name_obs in ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE: - kwargs = {} register( id=f"Meta-World/{name_obs}", - entry_point=lambda seed: ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE[name_obs]( + entry_point=lambda seed: ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE[name_obs]( # type: ignore seed=seed ), - kwargs=kwargs, + kwargs={}, ) - kwargs = {} - register( - id="Meta-World/MT10-sync", - vector_entry_point=lambda seed=None, use_one_hot=False, num_envs=None, *args, **lamb_kwargs: gym.vector.SyncVectorEnv( - [ - partial( - make_single_mt, - "MT1-" + env_name, - num_tasks=10, - env_id=idx, - seed=None if not seed else seed + idx, - use_one_hot=use_one_hot, - *args, - **lamb_kwargs, + for mt_bench in ["MT10", "MT50"]: + for vector_strategy in ["sync", "async"]: + register( + id=f"Meta-World/{mt_bench}-{vector_strategy}", + vector_entry_point=partial( + _mt_bench_vector_entry_point, mt_bench, vector_strategy + ), + kwargs={}, + ) + + for ml_bench in ["ML10", "ML45"]: + for vector_strategy in ["sync", "async"]: + for split in ["train", "test"]: + register( + id=f"Meta-World/{ml_bench}-{split}-{vector_strategy}", + vector_entry_point=partial( + _ml_bench_vector_entry_point, ml_bench, split, vector_strategy + ), ) - for idx, env_name in enumerate(list(_env_dict.MT10_V3.keys())) - ], - ), - kwargs=kwargs, - ) - register( - id="Meta-World/MT50-sync", - vector_entry_point=lambda seed=None, use_one_hot=False, num_envs=None, *args, **lamb_kwargs: gym.vector.SyncVectorEnv( - [ - partial( - make_single_mt, - "MT1-" + env_name, - num_tasks=50, - env_id=idx, - seed=None if not seed else seed + idx, - use_one_hot=use_one_hot, - *args, - **lamb_kwargs, - ) - for idx, env_name in enumerate(list(_env_dict.MT50_V3.keys())) - ] - ), - kwargs=kwargs, - ) - register( - id="Meta-World/MT50-async", - vector_entry_point=lambda seed=None, use_one_hot=False, num_envs=None, *args, **lamb_kwargs: gym.vector.AsyncVectorEnv( - [ - partial( - make_single_mt, - "MT1-" + env_name, - num_tasks=50, - env_id=idx, - seed=None if not seed else seed + idx, - use_one_hot=use_one_hot, - *args, - **lamb_kwargs, - ) - for idx, env_name in enumerate(list(_env_dict.MT50_V3.keys())) - ] - ), - kwargs=kwargs, - ) + for vector_strategy in ["sync", "async"]: + + def _custom_mt_vector_entry_point( + vector_strategy: str, + envs_list: list[str], + seed=None, + use_one_hot: bool = False, + num_envs=None, + *args, + **lamb_kwargs, + ): + vectorizer: type[gym.vector.VectorEnv] = getattr( + gym.vector, f"{vector_strategy.capitalize()}VectorEnv" + ) + return ( + vectorizer( # type: ignore + [ + partial( # type: ignore + make_mt_envs, + env_name, + num_tasks=len(envs_list), + env_id=idx, + seed=None if not seed else seed + idx, + use_one_hot=use_one_hot, + *args, + **lamb_kwargs, + ) + for idx, env_name in enumerate(envs_list) + ] + ), + ) - register( - id="Meta-World/MT10-async", - vector_entry_point=lambda seed=None, use_one_hot=False, num_envs=None, *args, **lamb_kwargs: gym.vector.AsyncVectorEnv( - [ - partial( - make_single_mt, - "MT1-" + env_name, - num_tasks=10, - env_id=idx, - seed=None if not seed else seed + idx, - use_one_hot=use_one_hot, - *args, - **lamb_kwargs, - ) - for idx, env_name in enumerate(list(_env_dict.MT10_V3.keys())) - ] - ), - kwargs=kwargs, - ) - - register( - id="Meta-World/ML10-train-sync", - vector_entry_point=lambda seed=None, meta_batch_size=20, num_envs=None, *args, **lamb_kwargs: gym.vector.SyncVectorEnv( - [ - partial( - 
make_single_ml_train, - "ML1-train-" + env_name, - tasks_per_env=meta_batch_size // 10, - env_num=idx % (meta_batch_size // 10), - seed=None if not seed else seed + (idx // (meta_batch_size // 10)), - *args, - **lamb_kwargs, - ) - for idx, env_name in enumerate( - sorted( - list(_env_dict.ML10_V3["train"].keys()) - * (meta_batch_size // 10) - ) - ) - ] - ), - kwargs=kwargs, - ) - - register( - id="Meta-World/ML10-test-sync", - vector_entry_point=lambda seed=None, meta_batch_size=20, num_envs=None, *args, **lamb_kwargs: gym.vector.SyncVectorEnv( - [ - partial( - make_single_ml_test, - "ML1-test-" + env_name, - tasks_per_env=meta_batch_size // 5, - env_num=idx % (meta_batch_size // 5), - seed=None if not seed else seed + (idx // (meta_batch_size // 5)), - total_tasks_per_cls=40, - *args, - **lamb_kwargs, - ) - for idx, env_name in enumerate( - sorted( - list(_env_dict.ML10_V3["test"].keys()) * (meta_batch_size // 5) - ) - ) - ] - ), - kwargs=kwargs, - ) - - register( - id="Meta-World/ML10-train-async", - vector_entry_point=lambda seed=None, meta_batch_size=20, num_envs=None, *args, **lamb_kwargs: gym.vector.AsyncVectorEnv( - [ - partial( - make_single_ml_train, - "ML1-train-" + env_name, - tasks_per_env=meta_batch_size // 10, - env_num=idx % (meta_batch_size // 10), - seed=None if not seed else seed + (idx // (meta_batch_size // 10)), - *args, - **lamb_kwargs, - ) - for idx, env_name in enumerate( - sorted( - list(_env_dict.ML10_V3["train"].keys()) - * (meta_batch_size // 10) - ) - ) - ] - ), - kwargs=kwargs, - ) - - register( - id="Meta-World/ML10-test-async", - vector_entry_point=lambda seed=None, meta_batch_size=20, num_envs=None, *args, **lamb_kwargs: gym.vector.AsyncVectorEnv( - [ - partial( - make_single_ml_test, - "ML1-test-" + env_name, - tasks_per_env=meta_batch_size // 5, - env_num=idx % (meta_batch_size // 5), - seed=None if not seed else seed + (idx // (meta_batch_size // 5)), - total_tasks_per_cls=40, - *args, - **lamb_kwargs, - ) - for idx, env_name in enumerate( - sorted( - list(_env_dict.ML10_V3["test"].keys()) * (meta_batch_size // 5) - ) - ) - ] - ), - kwargs=kwargs, - ) - - register( - id="Meta-World/ML45-train-sync", - vector_entry_point=lambda seed=None, meta_batch_size=45, num_envs=None, *args, **lamb_kwargs: gym.vector.SyncVectorEnv( - [ - partial( - make_single_ml_train, - "ML1-train-" + env_name, - tasks_per_env=meta_batch_size // 45, - env_num=idx % (meta_batch_size // 45), - seed=None if not seed else seed + (idx // (meta_batch_size // 45)), - *args, - **lamb_kwargs, - ) - for idx, env_name in enumerate( - sorted( - list(_env_dict.ML45_V3["train"].keys()) - * (meta_batch_size // 45) - ) - ) - ] - ), - kwargs=kwargs, - ) - - register( - id="Meta-World/ML45-test-sync", - vector_entry_point=lambda seed=None, meta_batch_size=45, num_envs=None, *args, **lamb_kwargs: gym.vector.SyncVectorEnv( - [ - partial( - make_single_ml_test, - "ML1-test-" + env_name, - tasks_per_env=meta_batch_size // 5, - env_num=idx % (meta_batch_size // 5), - seed=None if not seed else seed + (idx // (meta_batch_size // 5)), - total_tasks_per_cls=45, - *args, - **lamb_kwargs, - ) - for idx, env_name in enumerate( - sorted( - list(_env_dict.ML45_V3["test"].keys()) * (meta_batch_size // 5) - ) - ) - ] - ), - kwargs=kwargs, - ) - - register( - id="Meta-World/ML45-train-async", - vector_entry_point=lambda seed=None, meta_batch_size=45, num_envs=None, *args, **lamb_kwargs: gym.vector.AsyncVectorEnv( - [ - partial( - make_single_ml_train, - "ML1-train-" + env_name, - tasks_per_env=meta_batch_size // 
45, - env_num=idx % (meta_batch_size // 45), - seed=None if not seed else seed + (idx // (meta_batch_size // 45)), - *args, - **lamb_kwargs, - ) - for idx, env_name in enumerate( - sorted( - list(_env_dict.ML45_V3["train"].keys()) - * (meta_batch_size // 45) - ) - ) - ] - ), - kwargs=kwargs, - ) - - register( - id="Meta-World/ML45-test-async", - vector_entry_point=lambda seed=None, meta_batch_size=45, num_envs=None, *args, **lamb_kwargs: gym.vector.AsyncVectorEnv( - [ - partial( - make_single_ml_test, - "ML1-test-" + env_name, - tasks_per_env=meta_batch_size // 5, - env_num=idx % (meta_batch_size // 5), - seed=None if not seed else seed + (idx // (meta_batch_size // 5)), - total_tasks_per_cls=45, - *args, - **lamb_kwargs, - ) - for idx, env_name in enumerate( - sorted( - list(_env_dict.ML45_V3["test"].keys()) * (meta_batch_size // 5) - ) - ) - ] - ), - kwargs=kwargs, - ) - - register( - id="Meta-World/custom-mt-envs-sync", - vector_entry_point=lambda seed=None, use_one_hot=False, envs_list=None, num_envs=None, *args, **lamb_kwargs: gym.vector.SyncVectorEnv( - [ - partial( - make_single_mt, - "MT1-" + env_name, - num_tasks=len(envs_list), - env_id=idx, - seed=None if not seed else seed + idx, - use_one_hot=use_one_hot, - *args, - **lamb_kwargs, - ) - for idx, env_name in enumerate(envs_list) - ] - ), - kwargs=kwargs, - ) + register( + id=f"Meta-World/custom-mt-envs-{vector_strategy}", + vector_entry_point=partial(_custom_mt_vector_entry_point, vector_strategy), + kwargs={}, + ) - register( - id="Meta-World/custom-mt-envs-async", - vector_entry_point=lambda seed=None, use_one_hot=False, envs_list=None, num_envs=None, *args, **lamb_kwargs: gym.vector.AsyncVectorEnv( - [ - partial( - make_single_mt, - "MT1-" + env_name, - num_tasks=len(envs_list), - env_id=idx, - seed=None if not seed else seed + idx, - use_one_hot=use_one_hot, - *args, - **lamb_kwargs, - ) - for idx, env_name in enumerate(envs_list) - ] - ), - kwargs=kwargs, - ) + for vector_strategy in ["sync", "async"]: + + def _custom_ml_vector_entry_point( + vector_strategy: str, + train_envs: list[str], + test_envs: list[str], + meta_batch_size: int = 20, + seed=None, + num_envs=None, + *args, + **lamb_kwargs, + ): + return _make_ml_envs_inner( # type: ignore + CustomML(train_envs, test_envs, seed=seed), + meta_batch_size=meta_batch_size, + vector_strategy=vector_strategy, # type: ignore + *args, + **lamb_kwargs, + ) - register( - id="Meta-World/custom-ml-envs-sync", - vector_entry_point=lambda envs_list, seed=None, num_envs=None, meta_batch_size=None, *args, **lamb_kwargs: gym.vector.SyncVectorEnv( - [ - partial( - make_single_ml_train, - "ML1-train-" + env_name, - tasks_per_env=1, - env_num=0, - seed=None if not seed else seed + idx, - *args, - **lamb_kwargs, - ) - for idx, env_name in enumerate(envs_list) - ] - ), - kwargs=kwargs, - ) - - register( - id="Meta-World/custom-ml-envs-async", - vector_entry_point=lambda envs_list, seed=None, meta_batch_size=None, num_envs=None, *args, **lamb_kwargs: gym.vector.AsyncVectorEnv( - [ - partial( - make_single_ml_train, - "ML1-train-" + env_name, - tasks_per_env=1, - env_num=0, - seed=None if not seed else seed + idx, - *args, - **lamb_kwargs, - ) - for idx, env_name in enumerate(envs_list) - ] - ), - kwargs=kwargs, - ) + register( + id=f"Meta-World/custom-ml-envs-{vector_strategy}", + vector_entry_point=partial(_custom_ml_vector_entry_point, vector_strategy), + kwargs={}, + ) register_mw_envs() diff --git a/metaworld/evaluation.py b/metaworld/evaluation.py new file mode 100644 index 
00000000..21b97270
--- /dev/null
+++ b/metaworld/evaluation.py
@@ -0,0 +1,296 @@
+from __future__ import annotations
+
+from typing import NamedTuple, Protocol
+
+import gymnasium as gym
+import numpy as np
+import numpy.typing as npt
+
+from metaworld.env_dict import ALL_V3_ENVIRONMENTS
+
+
+class Agent(Protocol):
+    def eval_action(
+        self, obs: npt.NDArray[np.float64]
+    ) -> tuple[npt.NDArray[np.float64], dict[str, npt.NDArray]]:
+        ...
+
+
+class MetaLearningAgent(Agent):
+    def adapt(self, rollouts: Rollout) -> None:
+        ...
+
+
+def _get_task_names(
+    envs: gym.vector.SyncVectorEnv | gym.vector.AsyncVectorEnv,
+) -> list[str]:
+    metaworld_cls_to_task_name = {v.__name__: k for k, v in ALL_V3_ENVIRONMENTS.items()}
+    return [
+        metaworld_cls_to_task_name[task_name]
+        for task_name in envs.get_attr("task_name")
+    ]
+
+
+def evaluation(
+    agent: Agent,
+    eval_envs: gym.vector.SyncVectorEnv | gym.vector.AsyncVectorEnv,
+    num_episodes: int = 50,
+) -> tuple[float, float, dict[str, float]]:
+    terminate_on_success = np.all(eval_envs.get_attr("terminate_on_success")).item()
+    eval_envs.call("toggle_terminate_on_success", True)
+
+    obs: npt.NDArray[np.float64]
+    obs, _ = eval_envs.reset()
+    task_names = _get_task_names(eval_envs)
+    successes = {task_name: 0 for task_name in set(task_names)}
+    episodic_returns: dict[str, list[float]] = {
+        task_name: [] for task_name in set(task_names)
+    }
+
+    def eval_done(returns):
+        return all(len(r) >= num_episodes for _, r in returns.items())
+
+    while not eval_done(episodic_returns):
+        actions, _ = agent.eval_action(obs)
+        obs, _, terminations, truncations, infos = eval_envs.step(actions)
+        for i, env_ended in enumerate(np.logical_or(terminations, truncations)):
+            if env_ended:
+                episodic_returns[task_names[i]].append(float(infos["episode"]["r"][i]))
+                if len(episodic_returns[task_names[i]]) <= num_episodes:
+                    successes[task_names[i]] += int(infos["success"][i])
+
+    episodic_returns = {
+        task_name: returns[:num_episodes]
+        for task_name, returns in episodic_returns.items()
+    }
+
+    success_rate_per_task = {
+        task_name: task_successes / num_episodes
+        for task_name, task_successes in successes.items()
+    }
+    mean_success_rate = np.mean(list(success_rate_per_task.values()))
+    mean_returns = np.mean(list(episodic_returns.values()))
+
+    eval_envs.call("toggle_terminate_on_success", terminate_on_success)
+
+    return float(mean_success_rate), float(mean_returns), success_rate_per_task
+
+
+def metalearning_evaluation(
+    agent: MetaLearningAgent,
+    eval_envs: gym.vector.SyncVectorEnv | gym.vector.AsyncVectorEnv,
+    adaptation_steps: int = 1,
+    max_episode_steps: int = 500,
+    adaptation_episodes: int = 10,
+    num_episodes: int = 50,
+    num_evals: int = 1,
+) -> tuple[float, float, dict[str, float]]:
+    task_names = _get_task_names(eval_envs)
+
+    total_mean_success_rate = 0.0
+    total_mean_return = 0.0
+
+    success_rate_per_task = np.zeros((num_evals, len(set(task_names))))
+
+    for i in range(num_evals):
+        eval_envs.call("toggle_sample_tasks_on_reset", False)
+        eval_envs.call("toggle_terminate_on_success", False)
+        eval_envs.call("sample_tasks")
+        obs: npt.NDArray[np.float64]
+        obs, _ = eval_envs.reset()
+        obs = np.stack(obs)  # type: ignore
+        has_autoreset = np.full((eval_envs.num_envs,), False)
+        eval_buffer = _MultiTaskRolloutBuffer(
+            num_tasks=eval_envs.num_envs,
+            rollouts_per_task=adaptation_episodes,
+            max_episode_steps=max_episode_steps,
+        )
+
+        for _ in range(adaptation_steps):
+            while not eval_buffer.ready:
+                actions, aux_policy_outs = agent.eval_action(obs)
+                next_obs: npt.NDArray[np.float64]
+                rewards: npt.NDArray[np.float64]
+                next_obs, rewards, terminations, truncations, _ = eval_envs.step(
+                    actions
+                )
+                if not has_autoreset.any():
+                    eval_buffer.push(
+                        obs,
+                        actions,
+                        rewards,
+                        truncations,
+                        log_probs=aux_policy_outs.get("log_probs"),
+                        means=aux_policy_outs.get("means"),
+                        stds=aux_policy_outs.get("stds"),
+                    )
+                has_autoreset = np.logical_or(terminations, truncations)
+                obs = next_obs
+
+            rollouts = eval_buffer.get()
+            agent.adapt(rollouts)
+            eval_buffer.reset()
+
+        # Evaluation
+        mean_success_rate, mean_return, _success_rate_per_task = evaluation(
+            agent, eval_envs, num_episodes
+        )
+        total_mean_success_rate += mean_success_rate
+        total_mean_return += mean_return
+        success_rate_per_task[i] = np.array(list(_success_rate_per_task.values()))
+
+    success_rates = (success_rate_per_task).mean(axis=0)
+    task_success_rates = {
+        task_name: success_rates[i] for i, task_name in enumerate(set(task_names))
+    }
+
+    return (
+        total_mean_success_rate / num_evals,
+        total_mean_return / num_evals,
+        task_success_rates,
+    )
+
+
+class Rollout(NamedTuple):
+    observations: npt.NDArray
+    actions: npt.NDArray
+    rewards: npt.NDArray
+    dones: npt.NDArray
+
+    # Auxiliary policy outputs
+    log_probs: npt.NDArray | None = None
+    means: npt.NDArray | None = None
+    stds: npt.NDArray | None = None
+
+
+class _MultiTaskRolloutBuffer:
+    """A buffer to accumulate rollouts for multiple tasks.
+    Useful for ML1, ML10, ML45, or on-policy MTRL algorithms.
+
+    In Metaworld, all episodes are as long as the time limit (typically 500), thus in this buffer we assume
+    fixed-length episodes and leverage that for optimisations."""
+
+    rollouts: list[list[Rollout]]
+
+    def __init__(
+        self,
+        num_tasks: int,
+        rollouts_per_task: int,
+        max_episode_steps: int,
+    ):
+        self.num_tasks = num_tasks
+        self._rollouts_per_task = rollouts_per_task
+        self._max_episode_steps = max_episode_steps
+
+        self.reset()
+
+    def reset(self):
+        """Reset the buffer."""
+        self.rollouts = [[] for _ in range(self.num_tasks)]
+        self._running_rollouts = [[] for _ in range(self.num_tasks)]
+
+    @property
+    def ready(self) -> bool:
+        """Returns whether or not a full batch of rollouts for each task has been sampled."""
+        return all(len(t) == self._rollouts_per_task for t in self.rollouts)
+
+    def get_single_task(
+        self,
+        task_idx: int,
+    ) -> Rollout:
+        """Stack the collected rollouts for a single task.
+
+        Returns a Rollout tuple for a single task where each array has the batch dimensions (Timestep,).
+        The timesteps are multiple rollouts flattened into one time dimension."""
+        assert task_idx < self.num_tasks, "Task index out of bounds."
+
+        task_rollouts = Rollout(
+            *map(lambda *xs: np.stack(xs), *self.rollouts[task_idx])
+        )
+
+        assert task_rollouts.observations.shape[:2] == (
+            self._rollouts_per_task,
+            self._max_episode_steps,
+        ), "Buffer does not have the expected amount of data before sampling."
+
+        return task_rollouts
+
+    def get(
+        self,
+    ) -> Rollout:
+        """Stack the collected rollouts across all tasks.
+
+        Returns a Rollout tuple where each array has the batch dimensions (Task,Timestep,).
+ The timesteps are multiple rollouts flattened into one time dimension.""" + rollouts_per_task = [ + Rollout(*map(lambda *xs: np.stack(xs), *t)) for t in self.rollouts + ] + all_rollouts = Rollout(*map(lambda *xs: np.stack(xs), *rollouts_per_task)) + assert all_rollouts.observations.shape[:3] == ( + self.num_tasks, + self._rollouts_per_task, + self._max_episode_steps, + ), "Buffer does not have the expected amount of data before sampling." + + return all_rollouts + + def push( + self, + obs: npt.NDArray, + actions: npt.NDArray, + rewards: npt.NDArray, + dones: npt.NDArray, + log_probs: npt.NDArray | None = None, + means: npt.NDArray | None = None, + stds: npt.NDArray | None = None, + ): + """Add a batch of timesteps to the buffer. Multiple batch dims are supported, but they + need to multiply to the buffer's meta batch size. + + If an episode finishes here for any of the envs, pop the full rollout into the rollout buffer. + """ + assert np.prod(rewards.shape) == self.num_tasks + + obs = obs.copy() + actions = actions.copy() + assert obs.ndim == actions.ndim + if ( + obs.ndim > 2 and actions.ndim > 2 + ): # Flatten outer batch dims only if they exist + obs = obs.reshape(-1, *obs.shape[2:]) + actions = actions.reshape(-1, *actions.shape[2:]) + + rewards = rewards.reshape(-1, 1).copy() + dones = dones.reshape(-1, 1).copy() + if log_probs is not None: + log_probs = log_probs.reshape(-1, 1).copy() + if means is not None: + means = means.copy() + if means.ndim > 2: + means = means.reshape(-1, *means.shape[2:]) + if stds is not None: + stds = stds.copy() + if stds.ndim > 2: + stds = stds.reshape(-1, *stds.shape[2:]) + + for i in range(self.num_tasks): + timestep: tuple[npt.NDArray, ...] = ( + obs[i], + actions[i], + rewards[i], + dones[i], + ) + if log_probs is not None: + timestep += (log_probs[i],) + if means is not None: + timestep += (means[i],) + if stds is not None: + timestep += (stds[i],) + self._running_rollouts[i].append(timestep) + + if dones[i]: # pop full rollouts into the rollouts buffer + rollout = Rollout( + *map(lambda *xs: np.stack(xs), *self._running_rollouts[i]) + ) + self.rollouts[i].append(rollout) + self._running_rollouts[i] = [] diff --git a/metaworld/policies/__init__.py b/metaworld/policies/__init__.py index bbe1285b..37a2a1c6 100644 --- a/metaworld/policies/__init__.py +++ b/metaworld/policies/__init__.py @@ -73,6 +73,61 @@ from metaworld.policies.sawyer_window_close_v3_policy import SawyerWindowCloseV3Policy from metaworld.policies.sawyer_window_open_v3_policy import SawyerWindowOpenV3Policy +ENV_POLICY_MAP = dict( + { + "assembly-v3": SawyerAssemblyV3Policy, + "basketball-v3": SawyerBasketballV3Policy, + "bin-picking-v3": SawyerBinPickingV3Policy, + "box-close-v3": SawyerBoxCloseV3Policy, + "button-press-topdown-v3": SawyerButtonPressTopdownV3Policy, + "button-press-topdown-wall-v3": SawyerButtonPressTopdownWallV3Policy, + "button-press-v3": SawyerButtonPressV3Policy, + "button-press-wall-v3": SawyerButtonPressWallV3Policy, + "coffee-button-v3": SawyerCoffeeButtonV3Policy, + "coffee-pull-v3": SawyerCoffeePullV3Policy, + "coffee-push-v3": SawyerCoffeePushV3Policy, + "dial-turn-v3": SawyerDialTurnV3Policy, + "disassemble-v3": SawyerDisassembleV3Policy, + "door-close-v3": SawyerDoorCloseV3Policy, + "door-lock-v3": SawyerDoorLockV3Policy, + "door-open-v3": SawyerDoorOpenV3Policy, + "door-unlock-v3": SawyerDoorUnlockV3Policy, + "drawer-close-v3": SawyerDrawerCloseV3Policy, + "drawer-open-v3": SawyerDrawerOpenV3Policy, + "faucet-close-v3": SawyerFaucetCloseV3Policy, + 
"faucet-open-v3": SawyerFaucetOpenV3Policy, + "hammer-v3": SawyerHammerV3Policy, + "hand-insert-v3": SawyerHandInsertV3Policy, + "handle-press-side-v3": SawyerHandlePressSideV3Policy, + "handle-press-v3": SawyerHandlePressV3Policy, + "handle-pull-v3": SawyerHandlePullV3Policy, + "handle-pull-side-v3": SawyerHandlePullSideV3Policy, + "peg-insert-side-v3": SawyerPegInsertionSideV3Policy, + "lever-pull-v3": SawyerLeverPullV3Policy, + "peg-unplug-side-v3": SawyerPegUnplugSideV3Policy, + "pick-out-of-hole-v3": SawyerPickOutOfHoleV3Policy, + "pick-place-v3": SawyerPickPlaceV3Policy, + "pick-place-wall-v3": SawyerPickPlaceWallV3Policy, + "plate-slide-back-side-v3": SawyerPlateSlideBackSideV3Policy, + "plate-slide-back-v3": SawyerPlateSlideBackV3Policy, + "plate-slide-side-v3": SawyerPlateSlideSideV3Policy, + "plate-slide-v3": SawyerPlateSlideV3Policy, + "reach-v3": SawyerReachV3Policy, + "reach-wall-v3": SawyerReachWallV3Policy, + "push-back-v3": SawyerPushBackV3Policy, + "push-v3": SawyerPushV3Policy, + "push-wall-v3": SawyerPushWallV3Policy, + "shelf-place-v3": SawyerShelfPlaceV3Policy, + "soccer-v3": SawyerSoccerV3Policy, + "stick-pull-v3": SawyerStickPullV3Policy, + "stick-push-v3": SawyerStickPushV3Policy, + "sweep-into-v3": SawyerSweepIntoV3Policy, + "sweep-v3": SawyerSweepV3Policy, + "window-close-v3": SawyerWindowCloseV3Policy, + "window-open-v3": SawyerWindowOpenV3Policy, + } +) + __all__ = [ "SawyerAssemblyV3Policy", "SawyerBasketballV3Policy", @@ -124,4 +179,5 @@ "SawyerSweepV3Policy", "SawyerWindowOpenV3Policy", "SawyerWindowCloseV3Policy", + "ENV_POLICY_MAP", ] diff --git a/metaworld/wrappers.py b/metaworld/wrappers.py index f14b8bce..2128b991 100644 --- a/metaworld/wrappers.py +++ b/metaworld/wrappers.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from __future__ import annotations import gymnasium as gym import numpy as np @@ -12,6 +12,7 @@ class OneHotWrapper(gym.ObservationWrapper, gym.utils.RecordConstructorArgs): def __init__(self, env: Env, task_idx: int, num_tasks: int): gym.utils.RecordConstructorArgs.__init__(self) gym.ObservationWrapper.__init__(self, env) + assert isinstance(env.observation_space, gym.spaces.Box) env_lb = env.observation_space.low env_ub = env.observation_space.high one_hot_ub = np.ones(num_tasks) @@ -24,10 +25,6 @@ def __init__(self, env: Env, task_idx: int, num_tasks: int): np.concatenate([env_lb, one_hot_lb]), np.concatenate([env_ub, one_hot_ub]) ) - @property - def observation_space(self) -> gym.spaces.Space: - return self._observation_space - def observation(self, obs: NDArray) -> NDArray: return np.concatenate([obs, self.one_hot]) @@ -36,7 +33,7 @@ class RandomTaskSelectWrapper(gym.Wrapper): """A Gymnasium Wrapper to automatically set / reset the environment to a random task.""" - tasks: List[Task] + tasks: list[Task] sample_tasks_on_reset: bool = True def _set_random_task(self): @@ -46,29 +43,22 @@ def _set_random_task(self): def __init__( self, env: Env, - tasks: List[Task], + tasks: list[Task], sample_tasks_on_reset: bool = True, - seed: Optional[int] = None, ): super().__init__(env) self.tasks = tasks self.sample_tasks_on_reset = sample_tasks_on_reset - if seed: - self.unwrapped.seed(seed) def toggle_sample_tasks_on_reset(self, on: bool): self.sample_tasks_on_reset = on - def reset( - self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None - ): + def reset(self, *, seed: int | None = None, options: dict | None = None): if self.sample_tasks_on_reset: self._set_random_task() return self.env.reset(seed=seed, 
options=options) - def sample_tasks( - self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None - ): + def sample_tasks(self, *, seed: int | None = None, options: dict | None = None): self._set_random_task() return self.env.reset(seed=seed, options=options) @@ -82,14 +72,14 @@ class PseudoRandomTaskSelectWrapper(gym.Wrapper): Doesn't sample new tasks on reset by default. """ - tasks: List[Task] + tasks: list[Task] current_task_idx: int sample_tasks_on_reset: bool = False def _set_pseudo_random_task(self): self.current_task_idx = (self.current_task_idx + 1) % len(self.tasks) if self.current_task_idx == 0: - np.random.shuffle(self.tasks) + self.np_random.shuffle(self.tasks) self.unwrapped.set_task(self.tasks[self.current_task_idx]) def toggle_sample_tasks_on_reset(self, on: bool): @@ -98,27 +88,20 @@ def toggle_sample_tasks_on_reset(self, on: bool): def __init__( self, env: Env, - tasks: List[Task], + tasks: list[Task], sample_tasks_on_reset: bool = False, - seed: Optional[int] = None, ): super().__init__(env) self.sample_tasks_on_reset = sample_tasks_on_reset self.tasks = tasks self.current_task_idx = -1 - if seed: - np.random.seed(seed) - def reset( - self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None - ): + def reset(self, *, seed: int | None = None, options: dict | None = None): if self.sample_tasks_on_reset: self._set_pseudo_random_task() return self.env.reset(seed=seed, options=options) - def sample_tasks( - self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None - ): + def sample_tasks(self, *, seed: int | None = None, options: dict | None = None): self._set_pseudo_random_task() return self.env.reset(seed=seed, options=options) diff --git a/tests/metaworld/envs/mujoco/sawyer_xyz/test_scripted_policies.py b/tests/metaworld/envs/mujoco/sawyer_xyz/test_scripted_policies.py index 6db06cf0..96c50cd2 100644 --- a/tests/metaworld/envs/mujoco/sawyer_xyz/test_scripted_policies.py +++ b/tests/metaworld/envs/mujoco/sawyer_xyz/test_scripted_policies.py @@ -1,120 +1,22 @@ +import random + +import numpy as np import pytest from metaworld import MT1 -from metaworld.policies import ( - SawyerAssemblyV3Policy, - SawyerBasketballV3Policy, - SawyerBinPickingV3Policy, - SawyerBoxCloseV3Policy, - SawyerButtonPressTopdownV3Policy, - SawyerButtonPressTopdownWallV3Policy, - SawyerButtonPressV3Policy, - SawyerButtonPressWallV3Policy, - SawyerCoffeeButtonV3Policy, - SawyerCoffeePullV3Policy, - SawyerCoffeePushV3Policy, - SawyerDialTurnV3Policy, - SawyerDisassembleV3Policy, - SawyerDoorCloseV3Policy, - SawyerDoorLockV3Policy, - SawyerDoorOpenV3Policy, - SawyerDoorUnlockV3Policy, - SawyerDrawerCloseV3Policy, - SawyerDrawerOpenV3Policy, - SawyerFaucetCloseV3Policy, - SawyerFaucetOpenV3Policy, - SawyerHammerV3Policy, - SawyerHandInsertV3Policy, - SawyerHandlePressSideV3Policy, - SawyerHandlePressV3Policy, - SawyerHandlePullSideV3Policy, - SawyerHandlePullV3Policy, - SawyerLeverPullV3Policy, - SawyerPegInsertionSideV3Policy, - SawyerPegUnplugSideV3Policy, - SawyerPickOutOfHoleV3Policy, - SawyerPickPlaceV3Policy, - SawyerPickPlaceWallV3Policy, - SawyerPlateSlideBackSideV3Policy, - SawyerPlateSlideBackV3Policy, - SawyerPlateSlideSideV3Policy, - SawyerPlateSlideV3Policy, - SawyerPushBackV3Policy, - SawyerPushV3Policy, - SawyerPushWallV3Policy, - SawyerReachV3Policy, - SawyerReachWallV3Policy, - SawyerShelfPlaceV3Policy, - SawyerSoccerV3Policy, - SawyerStickPullV3Policy, - SawyerStickPushV3Policy, - SawyerSweepIntoV3Policy, - SawyerSweepV3Policy, - 
SawyerWindowCloseV3Policy,
-    SawyerWindowOpenV3Policy,
-)
-
-policies = dict(
-    {
-        "assembly-v3": SawyerAssemblyV3Policy,
-        "basketball-v3": SawyerBasketballV3Policy,
-        "bin-picking-v3": SawyerBinPickingV3Policy,
-        "box-close-v3": SawyerBoxCloseV3Policy,
-        "button-press-topdown-v3": SawyerButtonPressTopdownV3Policy,
-        "button-press-topdown-wall-v3": SawyerButtonPressTopdownWallV3Policy,
-        "button-press-v3": SawyerButtonPressV3Policy,
-        "button-press-wall-v3": SawyerButtonPressWallV3Policy,
-        "coffee-button-v3": SawyerCoffeeButtonV3Policy,
-        "coffee-pull-v3": SawyerCoffeePullV3Policy,
-        "coffee-push-v3": SawyerCoffeePushV3Policy,
-        "dial-turn-v3": SawyerDialTurnV3Policy,
-        "disassemble-v3": SawyerDisassembleV3Policy,
-        "door-close-v3": SawyerDoorCloseV3Policy,
-        "door-lock-v3": SawyerDoorLockV3Policy,
-        "door-open-v3": SawyerDoorOpenV3Policy,
-        "door-unlock-v3": SawyerDoorUnlockV3Policy,
-        "drawer-close-v3": SawyerDrawerCloseV3Policy,
-        "drawer-open-v3": SawyerDrawerOpenV3Policy,
-        "faucet-close-v3": SawyerFaucetCloseV3Policy,
-        "faucet-open-v3": SawyerFaucetOpenV3Policy,
-        "hammer-v3": SawyerHammerV3Policy,
-        "hand-insert-v3": SawyerHandInsertV3Policy,
-        "handle-press-side-v3": SawyerHandlePressSideV3Policy,
-        "handle-press-v3": SawyerHandlePressV3Policy,
-        "handle-pull-v3": SawyerHandlePullV3Policy,
-        "handle-pull-side-v3": SawyerHandlePullSideV3Policy,
-        "peg-insert-side-v3": SawyerPegInsertionSideV3Policy,
-        "lever-pull-v3": SawyerLeverPullV3Policy,
-        "peg-unplug-side-v3": SawyerPegUnplugSideV3Policy,
-        "pick-out-of-hole-v3": SawyerPickOutOfHoleV3Policy,
-        "pick-place-v3": SawyerPickPlaceV3Policy,
-        "pick-place-wall-v3": SawyerPickPlaceWallV3Policy,
-        "plate-slide-back-side-v3": SawyerPlateSlideBackSideV3Policy,
-        "plate-slide-back-v3": SawyerPlateSlideBackV3Policy,
-        "plate-slide-side-v3": SawyerPlateSlideSideV3Policy,
-        "plate-slide-v3": SawyerPlateSlideV3Policy,
-        "reach-v3": SawyerReachV3Policy,
-        "reach-wall-v3": SawyerReachWallV3Policy,
-        "push-back-v3": SawyerPushBackV3Policy,
-        "push-v3": SawyerPushV3Policy,
-        "push-wall-v3": SawyerPushWallV3Policy,
-        "shelf-place-v3": SawyerShelfPlaceV3Policy,
-        "soccer-v3": SawyerSoccerV3Policy,
-        "stick-pull-v3": SawyerStickPullV3Policy,
-        "stick-push-v3": SawyerStickPushV3Policy,
-        "sweep-into-v3": SawyerSweepIntoV3Policy,
-        "sweep-v3": SawyerSweepV3Policy,
-        "window-close-v3": SawyerWindowCloseV3Policy,
-        "window-open-v3": SawyerWindowOpenV3Policy,
-    }
-)
+from metaworld.policies import ENV_POLICY_MAP


 @pytest.mark.parametrize("env_name", MT1.ENV_NAMES)
 def test_policy(env_name):
-    mt1 = MT1(env_name)
+    SEED = 42
+    random.seed(SEED)
+    np.random.seed(SEED)
+
+    mt1 = MT1(env_name, seed=SEED)
     env = mt1.train_classes[env_name]()
-    p = policies[env_name]()
+    env.seed(SEED)
+    p = ENV_POLICY_MAP[env_name]()
     completed = 0
     for task in mt1.train_tasks:
         env.set_task(task)
diff --git a/tests/metaworld/test_evaluation.py b/tests/metaworld/test_evaluation.py
new file mode 100644
index 00000000..9ef630f4
--- /dev/null
+++ b/tests/metaworld/test_evaluation.py
@@ -0,0 +1,136 @@
+from __future__ import annotations
+
+import random
+
+import gymnasium as gym
+import numpy as np
+import numpy.typing as npt
+import pytest
+
+import metaworld  # noqa: F401
+from metaworld import evaluation
+from metaworld.policies import ENV_POLICY_MAP
+
+
+class ScriptedPolicyAgent(evaluation.MetaLearningAgent):
+    def __init__(
+        self,
+        envs: gym.vector.SyncVectorEnv | gym.vector.AsyncVectorEnv,
+        num_rollouts: int | None = None,
+        max_episode_steps: int | None = None,
+    ):
+        env_task_names = evaluation._get_task_names(envs)
+        self.policies = [ENV_POLICY_MAP[task]() for task in env_task_names]  # type: ignore
+        self.num_rollouts = num_rollouts
+        self.max_episode_steps = max_episode_steps
+        self.adapt_calls = 0
+
+    def eval_action(
+        self, obs: npt.NDArray[np.float64]
+    ) -> tuple[npt.NDArray[np.float64], dict[str, npt.NDArray]]:
+        actions: list[npt.NDArray[np.float32]] = []
+        num_envs = len(self.policies)
+        for env_idx in range(num_envs):
+            actions.append(self.policies[env_idx].get_action(obs[env_idx]))
+        stacked_actions = np.stack(actions, axis=0, dtype=np.float64)
+        return stacked_actions, {
+            "log_probs": np.ones((num_envs,)),
+            "means": stacked_actions,
+            "stds": np.zeros((num_envs,)),
+        }
+
+    def adapt(self, rollouts: evaluation.Rollout) -> None:
+        assert self.num_rollouts is not None
+
+        for key in [
+            "observations",
+            "rewards",
+            "actions",
+            "dones",
+            "log_probs",
+            "means",
+            "stds",
+        ]:
+            assert len(getattr(rollouts, key).shape) >= 3
+            assert getattr(rollouts, key).shape[0] == len(self.policies)
+            assert getattr(rollouts, key).shape[1] == self.num_rollouts
+            assert getattr(rollouts, key).shape[2] == self.max_episode_steps
+
+        self.adapt_calls += 1
+
+
+class RemovePartialObservabilityWrapper(gym.vector.VectorWrapper):
+    def get_attr(self, name):
+        return self.env.get_attr(name)
+
+    def set_attr(self, name, values):
+        return self.env.set_attr(name, values)
+
+    def call(self, name, *args, **kwargs):
+        return self.env.call(name, *args, **kwargs)
+
+    def step(self, actions):
+        self.env.set_attr("_partially_observable", False)
+        return super().step(actions)
+
+
+def test_evaluation():
+    SEED = 42
+    max_episode_steps = 300  # To speed up the test
+    num_episodes = 50
+
+    random.seed(SEED)
+    np.random.seed(SEED)
+    envs = gym.make_vec(
+        "Meta-World/MT50-async", seed=SEED, max_episode_steps=max_episode_steps
+    )
+    agent = ScriptedPolicyAgent(envs)
+    mean_success_rate, mean_returns, success_rate_per_task = evaluation.evaluation(
+        agent, envs, num_episodes=num_episodes
+    )
+    assert isinstance(mean_returns, float)
+    assert mean_success_rate >= 0.80
+    assert len(success_rate_per_task) == envs.num_envs
+    assert np.all(np.array(list(success_rate_per_task.values())) >= 0.80)
+
+
+@pytest.mark.parametrize("benchmark", ("ML10", "ML45"))
+def test_metalearning_evaluation(benchmark):
+    SEED = 42
+
+    max_episode_steps = 300
+    meta_batch_size = 10  # Number of parallel envs
+
+    adaptation_steps = 2  # Number of adaptation iterations
+    adaptation_episodes = 2  # Number of train episodes per task in meta_batch_size per adaptation iteration
+    num_evals = 50  # Number of different task vectors tested for each task
+    num_episodes = 1  # Number of test episodes per task vector
+
+    random.seed(SEED)
+    np.random.seed(SEED)
+    envs = gym.make_vec(
+        f"Meta-World/{benchmark}-test-async",
+        seed=SEED,
+        meta_batch_size=meta_batch_size,
+        max_episode_steps=max_episode_steps,
+    )
+    envs = RemovePartialObservabilityWrapper(envs)
+    agent = ScriptedPolicyAgent(envs, adaptation_episodes, max_episode_steps)
+    (
+        mean_success_rate,
+        mean_returns,
+        success_rate_per_task,
+    ) = evaluation.metalearning_evaluation(
+        agent,
+        envs,
+        max_episode_steps=max_episode_steps,
+        num_episodes=num_episodes,
+        adaptation_episodes=adaptation_episodes,
+        adaptation_steps=adaptation_steps,
+        num_evals=num_evals,
+    )
+    assert isinstance(mean_returns, float)
+    assert mean_success_rate >= 0.80
+    assert len(success_rate_per_task) == len(set(evaluation._get_task_names(envs)))
+    assert np.all(np.array(list(success_rate_per_task.values())) >= 0.80)
+    assert agent.adapt_calls == num_evals * adaptation_steps
diff --git a/tests/metaworld/test_gym_make.py b/tests/metaworld/test_gym_make.py
index 6fe0e2c9..33ee9211 100644
--- a/tests/metaworld/test_gym_make.py
+++ b/tests/metaworld/test_gym_make.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import random

 from typing import Literal
@@ -159,7 +161,11 @@ def test_ml_benchmarks(
     vector_strategy: str,
 ):
     meta_batch_size = 20 if benchmark != "ML45" else 45
-    total_tasks_per_cls = _N_GOALS if benchmark != "ML45" else 45
+    total_tasks_per_cls = _N_GOALS
+    if benchmark == "ML45":
+        total_tasks_per_cls = 45
+    elif benchmark == "ML10" and split == "test":
+        total_tasks_per_cls = 40
     max_episode_steps = 10
     envs = gym.make_vec(
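Reviewer note: the snippet below is an illustrative usage sketch, not part of the diff. It exercises the new Gymnasium registration IDs (here "Meta-World/MT10-sync") and the metaworld.evaluation.evaluation helper added above; the RandomAgent class is a hypothetical stand-in that only satisfies the Agent protocol, so swap in a real policy for meaningful numbers.

import gymnasium as gym
import numpy as np

import metaworld  # noqa: F401  # importing metaworld runs register_mw_envs()
from metaworld import evaluation


class RandomAgent:
    """Hypothetical stand-in that satisfies the Agent protocol from metaworld/evaluation.py."""

    def __init__(self, envs: gym.vector.VectorEnv):
        self.action_space = envs.single_action_space
        self.num_envs = envs.num_envs

    def eval_action(self, obs):
        # Sample one action per sub-environment; no auxiliary policy outputs.
        actions = np.stack([self.action_space.sample() for _ in range(self.num_envs)])
        return actions, {}


envs = gym.make_vec("Meta-World/MT10-sync", seed=42, use_one_hot=True)
agent = RandomAgent(envs)
mean_success, mean_return, per_task = evaluation.evaluation(agent, envs, num_episodes=5)
print(mean_success, mean_return, per_task)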