From a0e7a65554ae64a13b9a3cda1aac379c5ff4bc73 Mon Sep 17 00:00:00 2001 From: Taewoon Kim Date: Thu, 13 Oct 2022 22:40:39 +0200 Subject: [PATCH] 0.2.2 ready --- README.md | 2 +- documents/README-v0.md | 2 +- documents/README-v1.md | 2 +- room-env-v1.ipynb | 100 +++++++++++++++++++++++---------------- room_env/envs/room0.py | 6 ++- room_env/envs/room1.py | 6 ++- room_env/envs/room2.py | 13 ++--- room_env/utils.py | 23 +++++---- setup.cfg | 2 +- test/test_room_env_v0.py | 2 +- test/test_room_env_v1.py | 2 +- test/test_room_env_v2.py | 12 +++-- 12 files changed, 103 insertions(+), 69 deletions(-) diff --git a/README.md b/README.md index 10b3d9f..20dd10b 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ import room_env env = gym.make("RoomEnv-v2") observation, info = env.reset() while True: - observation, reward, done, info = env.step(0) + observation, reward, done, truncated, info = env.step(0) if done: break ``` diff --git a/documents/README-v0.md b/documents/README-v0.md index bfc38c6..fc732e2 100644 --- a/documents/README-v0.md +++ b/documents/README-v0.md @@ -89,7 +89,7 @@ import room_env env = gym.make("RoomEnv-v0") (observation, question), info = env.reset() while True: - (observation, question), reward, done, info = env.step("This is my answer!") + (observation, question), reward, done, truncated, info = env.step("This is my answer!") if done: break ``` diff --git a/documents/README-v1.md b/documents/README-v1.md index 4f2fe21..717238a 100644 --- a/documents/README-v1.md +++ b/documents/README-v1.md @@ -65,7 +65,7 @@ import room_env env = gym.make("RoomEnv-v1") (observation, question), info = env.reset() while True: - (observation, question), reward, done, info = env.step("This is my answer!") + (observation, question), reward, done, truncated, info = env.step("This is my answer!") if done: break ``` diff --git a/room-env-v1.ipynb b/room-env-v1.ipynb index 5a46432..03c0472 100644 --- a/room-env-v1.ipynb +++ b/room-env-v1.ipynb @@ -9,7 +9,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/tk/.virtualenvs/dev-python3.8/lib/python3.8/site-packages/gym/utils/passive_env_checker.py:174: UserWarning: \u001b[33mWARN: Future gym versions will require that `Env.reset` can be passed a `seed` instead of using `Env.seed` for resetting the environment random number generator.\u001b[0m\n", + " 0%| | 0/6 [00:00\u001b[0m\n", " logger.warn(f\"{pre} should be an int or np.int64, actual type: {type(obs)}\")\n", "/home/tk/.virtualenvs/dev-python3.8/lib/python3.8/site-packages/gym/utils/passive_env_checker.py:165: UserWarning: \u001b[33mWARN: The obs returned by the `step()` method is not within the observation space.\u001b[0m\n", - " logger.warn(f\"{pre} is not within the observation space.\")\n" + " logger.warn(f\"{pre} is not within the observation space.\")\n", + "100%|██████████| 6/6 [01:14<00:00, 12.48s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "capacity=2, strategy=episodic,\tpre_sem=False,\trewards_mean=-11.0, rewards_std=4.266\n", - "capacity=2, strategy=semantic,\tpre_sem=False,\trewards_mean=-9.8, rewards_std=3.156\n", - "capacity=2, strategy=random,\tpre_sem=False,\trewards_mean=-10.0, rewards_std=3.521\n", - "capacity=2, strategy=pre_sem,\tpre_sem=True,\trewards_mean=-10.0, rewards_std=3.347\n", - "\n", - "capacity=4, strategy=episodic,\tpre_sem=False,\trewards_mean=-9.8, rewards_std=3.894\n", - "capacity=4, strategy=semantic,\tpre_sem=False,\trewards_mean=-7.2, rewards_std=3.709\n", - "capacity=4, strategy=random,\tpre_sem=False,\trewards_mean=-8.8, rewards_std=3.37\n", - "capacity=4, strategy=pre_sem,\tpre_sem=True,\trewards_mean=-9.2, rewards_std=4.118\n", - "\n", - "capacity=8, strategy=episodic,\tpre_sem=False,\trewards_mean=-8.8, rewards_std=3.816\n", - "capacity=8, strategy=semantic,\tpre_sem=False,\trewards_mean=-2.6, rewards_std=4.779\n", - "capacity=8, strategy=random,\tpre_sem=False,\trewards_mean=-7.6, rewards_std=3.878\n", - "capacity=8, strategy=pre_sem,\tpre_sem=True,\trewards_mean=-7.0, rewards_std=4.494\n", - "\n", - "capacity=16, strategy=episodic,\tpre_sem=False,\trewards_mean=-6.4, rewards_std=3.499\n", - "capacity=16, strategy=semantic,\tpre_sem=False,\trewards_mean=0.6, rewards_std=4.294\n", - "capacity=16, strategy=random,\tpre_sem=False,\trewards_mean=-4.4, rewards_std=4.03\n", - "capacity=16, strategy=pre_sem,\tpre_sem=True,\trewards_mean=-3.0, rewards_std=5.196\n", - "\n", - "capacity=32, strategy=episodic,\tpre_sem=False,\trewards_mean=-3.2, rewards_std=4.445\n", - "capacity=32, strategy=semantic,\tpre_sem=False,\trewards_mean=2.8, rewards_std=3.919\n", - "capacity=32, strategy=random,\tpre_sem=False,\trewards_mean=-2.0, rewards_std=4.561\n", - "capacity=32, strategy=pre_sem,\tpre_sem=True,\trewards_mean=7.2, rewards_std=3.919\n", - "\n", - "capacity=64, strategy=episodic,\tpre_sem=False,\trewards_mean=-1.2, rewards_std=4.578\n", - "capacity=64, strategy=semantic,\tpre_sem=False,\trewards_mean=3.0, rewards_std=4.626\n", - "capacity=64, strategy=random,\tpre_sem=False,\trewards_mean=-1.2, rewards_std=5.307\n", - "capacity=64, strategy=pre_sem,\tpre_sem=True,\trewards_mean=6.8, rewards_std=3.763\n", + "{2: {'episodic': {'mean': -97.2, 'std': 7.111},\n", + " 'pre_sem': {'mean': -88.8, 'std': 8.01},\n", + " 'random': {'mean': -97.8, 'std': 5.618},\n", + " 'semantic': {'mean': -85.0, 'std': 7.225}},\n", + " 4: {'episodic': {'mean': -84.2, 'std': 7.718},\n", + " 'pre_sem': {'mean': -75.0, 'std': 7.169},\n", + " 'random': {'mean': -79.6, 'std': 6.741},\n", + " 'semantic': {'mean': -57.6, 'std': 8.04}},\n", + " 8: {'episodic': {'mean': -62.0, 'std': 8.944},\n", + " 'pre_sem': {'mean': -47.2, 'std': 8.256},\n", + " 'random': {'mean': -45.0, 'std': 3.606},\n", + " 'semantic': {'mean': -10.4, 'std': 19.2}},\n", + " 16: {'episodic': {'mean': -19.8, 'std': 11.294},\n", + " 'pre_sem': {'mean': -5.2, 'std': 9.558},\n", + " 'random': {'mean': -6.8, 'std': 7.859},\n", + " 'semantic': {'mean': 40.8, 'std': 9.968}},\n", + " 32: {'episodic': {'mean': 50.4, 'std': 8.429},\n", + " 'pre_sem': {'mean': 87.6, 'std': 7.736},\n", + " 'random': {'mean': 37.0, 'std': 10.325},\n", + " 'semantic': {'mean': 55.6, 'std': 12.484}},\n", + " 64: {'episodic': {'mean': 128.0, 'std': 0.0},\n", + " 'pre_sem': {'mean': 107.0, 'std': 5.459},\n", + " 'random': {'mean': 52.6, 'std': 13.507},\n", + " 'semantic': {'mean': 57.8, 'std': 9.652}}}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ "\n" ] } @@ -68,14 +70,15 @@ "logger = logging.getLogger()\n", "logger.disabled = True\n", "\n", - "from room_env.utils import print_handcrafted\n", + "from pprint import pprint\n", + "from room_env.utils import get_handcrafted\n", "\n", "\n", - "print_handcrafted(\n", + "results = get_handcrafted(\n", " env=\"RoomEnv-v2\",\n", " des_size=\"l\",\n", " seeds=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],\n", - " question_prob=0.1,\n", + " question_prob=1,\n", " policies={\n", " \"memory_management\": \"rl\",\n", " \"question_answer\": \"episodic_semantic\",\n", @@ -83,17 +86,32 @@ " },\n", " capacities=[2, 4, 8, 16, 32, 64],\n", " allow_random_human=False,\n", - " allow_random_question=True,\n", + " allow_random_question=False,\n", " varying_rewards=False,\n", " check_resources=True,\n", - ")" + " version=\"v1\",\n", + ")\n", + "pprint(results)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/tk/.virtualenvs/dev-python3.8/lib/python3.8/site-packages/gym/envs/registration.py:555: UserWarning: \u001b[33mWARN: The environment RoomEnv-v1 is out of date. You should consider upgrading to version `v2`.\u001b[0m\n", + " logger.warn(\n", + "/home/tk/.virtualenvs/dev-python3.8/lib/python3.8/site-packages/gym/utils/passive_env_checker.py:133: UserWarning: \u001b[33mWARN: The obs returned by the `reset()` method should be an int or np.int64, actual type: \u001b[0m\n", + " logger.warn(f\"{pre} should be an int or np.int64, actual type: {type(obs)}\")\n", + "/home/tk/.virtualenvs/dev-python3.8/lib/python3.8/site-packages/gym/utils/passive_env_checker.py:133: UserWarning: \u001b[33mWARN: The obs returned by the `step()` method should be an int or np.int64, actual type: \u001b[0m\n", + " logger.warn(f\"{pre} should be an int or np.int64, actual type: {type(obs)}\")\n" + ] + } + ], "source": [ "import gym\n", "import room_env\n", @@ -101,14 +119,14 @@ "env = gym.make(\"RoomEnv-v1\")\n", "(observation, question), info = env.reset()\n", "while True:\n", - " (observation, question), reward, done, info = env.step(\"foo\")\n", + " (observation, question), reward, done, truncated, info = env.step(\"foo\")\n", " if done:\n", " break" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -118,7 +136,7 @@ "env = gym.make(\"RoomEnv-v2\")\n", "observation, info = env.reset()\n", "while True:\n", - " observation, reward, done, info = env.step(0)\n", + " observation, reward, done, truncated, info = env.step(0)\n", " if done:\n", " break" ] diff --git a/room_env/envs/room0.py b/room_env/envs/room0.py index f558677..d761ab4 100644 --- a/room_env/envs/room0.py +++ b/room_env/envs/room0.py @@ -315,7 +315,7 @@ def renew(self) -> None: self.room = room - def step(self, action: str) -> Tuple[Tuple[dict, list], int, bool, dict]: + def step(self, action: str) -> Tuple[Tuple, int, bool, bool, dict]: """An agent takes an action. Args @@ -350,7 +350,9 @@ def step(self, action: str) -> Tuple[Tuple[dict, list], int, bool, dict]: else: done = False - return (observations, question), reward, done, info + truncated = False + + return (observations, question), reward, done, truncated, info def render(self, mode="console") -> None: if mode != "console": diff --git a/room_env/envs/room1.py b/room_env/envs/room1.py index bfc1fb8..8934b01 100644 --- a/room_env/envs/room1.py +++ b/room_env/envs/room1.py @@ -172,7 +172,7 @@ def reset(self) -> Tuple[dict, dict]: return (observation, question), info - def step(self, action: str) -> None: + def step(self, action: str) -> Tuple[Tuple, int, bool, bool, dict]: """An agent takes an action. Args @@ -200,7 +200,9 @@ def step(self, action: str) -> None: else: done = False - return (observation, question), reward, done, info + truncated = False + + return (observation, question), reward, done, truncated, info def render(self, mode="console") -> None: if mode != "console": diff --git a/room_env/envs/room2.py b/room_env/envs/room2.py index 61988d2..f28c924 100644 --- a/room_env/envs/room2.py +++ b/room_env/envs/room2.py @@ -373,7 +373,7 @@ def reset(self) -> dict: raise ValueError - def step(self, action: int) -> Tuple[dict, int, bool, dict]: + def step(self, action: int) -> Tuple[Tuple, int, bool, bool, dict]: """An agent takes an action. Args @@ -382,10 +382,11 @@ def step(self, action: int) -> Tuple[dict, int, bool, dict]: Returns ------- - state, reward, done, info + state, reward, done, truncated, info """ info = {} + truncated = False if self.policies["encoding"].lower() == "rl": # This is a dummy code self.obs = self.obs[action] @@ -412,7 +413,7 @@ def step(self, action: int) -> Tuple[dict, int, bool, dict]: else: done = False - return state, reward, done, info + return state, reward, done, truncated, info if self.policies["memory_management"].lower() == "rl": if action == 0: @@ -446,7 +447,7 @@ def step(self, action: int) -> Tuple[dict, int, bool, dict]: else: done = False - return state, reward, done, info + return state, reward, done, truncated, info if self.policies["question_answer"].lower() == "rl": if action == 0: @@ -476,7 +477,7 @@ def step(self, action: int) -> Tuple[dict, int, bool, dict]: if self.is_last: state = None done = True - return state, reward, done, info + return state, reward, done, truncated, info else: done = False @@ -488,7 +489,7 @@ def step(self, action: int) -> Tuple[dict, int, bool, dict]: "question": deepcopy(self.question), } - return state, reward, done, info + return state, reward, done, truncated, info def render(self, mode="console") -> None: if mode != "console": diff --git a/room_env/utils.py b/room_env/utils.py index 2ba0824..976fcaa 100644 --- a/room_env/utils.py +++ b/room_env/utils.py @@ -359,7 +359,7 @@ def run_des_seeds( action = 0 else: raise ValueError - state, reward, done, info = env.step(action) + state, reward, done, truncated, info = env.step(action) rewards += reward if done: break @@ -511,7 +511,7 @@ def fill_des_resources(des_size: str, version: str) -> None: des = RoomDes(des_size=des_size, check_resources=True) -def print_handcrafted( +def get_handcrafted( env: str = "RoomEnv-v2", des_size: str = "l", seeds: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], @@ -529,7 +529,7 @@ def print_handcrafted( check_resources: bool = True, version: str = "v2", ) -> None: - """Plot the env results with handcrafted policies. + """Get the env results with handcrafted policies. At the moment only {"memory_management": "rl"} is supported. @@ -550,11 +550,18 @@ def print_handcrafted( varying_rewards: If true, then the rewards are scaled in every episode so that total_episode_rewards is 128. version: Use v2 or v1. v2 recommended. + + Returns + ------- + handcrafted_results + """ how_to_forget = ["episodic", "semantic", "random", "pre_sem"] env_ = env + handcrafted_results = {} for capacity in capacities: + handcrafted_results[capacity] = {} for forget_short in how_to_forget: if forget_short == "random": @@ -602,7 +609,7 @@ def print_handcrafted( action = 0 else: raise ValueError - state, reward, done, info = env.step(action) + state, reward, done, truncated, info = env.step(action) rewards += reward if done: break @@ -610,8 +617,6 @@ def print_handcrafted( mean_ = np.mean(results).round(3).item() std_ = np.std(results).round(3).item() - print( - f"capacity={capacity}, strategy={forget_short}," - f"\tpre_sem={pretrain_semantic},\trewards_mean={mean_}, rewards_std={std_}" - ) - print() + handcrafted_results[capacity][forget_short] = {"mean": mean_, "std": std_} + + return handcrafted_results diff --git a/setup.cfg b/setup.cfg index 94ab2cf..11f01e2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = room_env -version = 0.2.1 +version = 0.2.2 author = Taewoon Kim author_email = tae898@gmail.com description = The Room environment diff --git a/test/test_room_env_v0.py b/test/test_room_env_v0.py index 423bcd6..d0b5e89 100644 --- a/test/test_room_env_v0.py +++ b/test/test_room_env_v0.py @@ -11,6 +11,6 @@ def test_all(self) -> None: env = gym.make("RoomEnv-v0", room_size=room_size) observations, info = env.reset() while True: - observations, reward, done, info = env.step("foo") + observations, reward, done, truncated, info = env.step("foo") if done: break diff --git a/test/test_room_env_v1.py b/test/test_room_env_v1.py index 5d0b4da..3faf03f 100644 --- a/test/test_room_env_v1.py +++ b/test/test_room_env_v1.py @@ -57,6 +57,6 @@ def test_all(self) -> None: env = gym.make("RoomEnv-v1", des_size=des_size) (observations, question), info = env.reset() while True: - observations, reward, done, info = env.step(0) + observations, reward, done, truncated, info = env.step(0) if done: break diff --git a/test/test_room_env_v2.py b/test/test_room_env_v2.py index 3e91ecc..6e1ef91 100644 --- a/test/test_room_env_v2.py +++ b/test/test_room_env_v2.py @@ -33,7 +33,13 @@ def test_all(self) -> None: ) state, info = env.reset() while True: - state, reward, done, info = env.step(0) + ( + state, + reward, + done, + truncated, + info, + ) = env.step(0) if done: break @@ -121,7 +127,7 @@ def test_reset_qa(self) -> None: self.assertEqual(len(state["memory_systems"]["short"]), 0) while True: - state, reward, done, info = env.step(random.randint(0, 1)) + state, reward, done, truncated, info = env.step(random.randint(0, 1)) if done: break self.assertIn("episodic", state["memory_systems"]) @@ -151,7 +157,7 @@ def test_reset_memory_management(self) -> None: self.assertEqual(len(state["episodic"]), 0) self.assertEqual(len(state["short"]), 1) while True: - state, reward, done, info = env.step(random.randint(0, 2)) + state, reward, done, truncated, info = env.step(random.randint(0, 2)) if done: break self.assertIn("episodic", state)