From fcc00a0439f1cc96150a72605ae3af38b27df68a Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Tue, 12 Mar 2024 13:09:21 -0400 Subject: [PATCH 1/6] Fix attribute error --- pokemonred_puffer/environment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 6102aa5..1984914 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -500,7 +500,7 @@ def check_if_in_bag_menu(self) -> bool: ) def check_if_action_in_bag_menu(self, action) -> bool: - return action == WindowEvent.PRESS_BUTTON_A and self.check_if_in_back_menu() + return action == WindowEvent.PRESS_BUTTON_A and self.check_if_in_bag_menu() def check_if_in_overworld(self) -> bool: return self.read_m(0xD057) == 0 and self.read_m(0xCF13) == 0 and self.read_m(0xFF8C) == 0 From ee634e373218ab933bd3cc1bc2702dfcb8cf6666 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Tue, 12 Mar 2024 13:14:05 -0400 Subject: [PATCH 2/6] Add pre-commit --- .pre-commit-config.yaml | 9 +++++++++ pyproject.toml | 1 + 2 files changed, 10 insertions(+) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..8f03df7 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,9 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.3.2 + hooks: + # Run the linter. + - id: ruff + # Run the formatter. + - id: ruff-format \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index da4fb6d..44a2158 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ monitoring = [ "nvitop" ] dev = [ + "pre-commit", "ruff" ] From 81ee29d01762a74ac113ae470e1e1e21061defa8 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Tue, 12 Mar 2024 13:14:46 -0400 Subject: [PATCH 3/6] Add pre-commit --- pokemonred_puffer/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pokemonred_puffer/train.py b/pokemonred_puffer/train.py index e3ab2c4..f71871e 100644 --- a/pokemonred_puffer/train.py +++ b/pokemonred_puffer/train.py @@ -204,7 +204,7 @@ def train( "-p", "--policy-name", default="multi_convolutional.MultiConvolutionalPolicy", - help="Policy module to use in policies", + help="Policy module to use in policies.", ) parser.add_argument( "-r", From c8fb101591a7f0ff5f18ca7066eddf8cfa00a2d7 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Tue, 12 Mar 2024 13:23:16 -0400 Subject: [PATCH 4/6] Run ruff over the repo --- pokemonred_puffer/cleanrl_puffer.py | 8 +++----- pokemonred_puffer/environment.py | 2 -- pokemonred_puffer/global_map.py | 16 ++++++++++------ .../policies/multi_convolutional.py | 2 ++ pokemonred_puffer/resnet.py | 4 +++- pokemonred_puffer/wrappers/stream_wrapper.py | 4 ++-- 6 files changed, 20 insertions(+), 16 deletions(-) diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index 48a9df7..7481021 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -7,7 +7,6 @@ import uuid from collections import defaultdict from datetime import timedelta -import warnings import numpy as np import pufferlib @@ -145,7 +144,7 @@ def print_dashboard(stats, init_performance, performance): elif "time" in k: try: v = f"{v:.2f} s" - except: + except: # noqa pass first_word, *rest_words = k.split("_") @@ -343,7 +342,6 @@ def evaluate(self): self.log = False self.policy_pool.update_policies() - performance = defaultdict(list) env_profiler = pufferlib.utils.Profiler() inference_profiler = pufferlib.utils.Profiler() eval_profiler = pufferlib.utils.Profiler(memory=True, pytorch_memory=True).start() @@ -494,7 +492,7 @@ def evaluate(self): self.stats[f"Histogram/{k}"] = self.wandb.Histogram(v, num_bins=16) self.stats[k] = np.mean(v) self.max_stats[k] = np.max(v) - except: + except: # noqa continue if config.verbose: @@ -551,7 +549,7 @@ def train(self): self.b_obs = b_obs = torch.Tensor(self.obs_ary[b_idxs]) b_actions = torch.Tensor(self.actions_ary[b_idxs]).to(self.device, non_blocking=True) b_logprobs = torch.Tensor(self.logprobs_ary[b_idxs]).to(self.device, non_blocking=True) - b_dones = torch.Tensor(self.dones_ary[b_idxs]).to(self.device, non_blocking=True) + # b_dones = torch.Tensor(self.dones_ary[b_idxs]).to(self.device, non_blocking=True) b_values = torch.Tensor(self.values_ary[b_idxs]).to(self.device, non_blocking=True) b_advantages = advantages.reshape( config.batch_rows, num_minibatches, config.bptt_horizon diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 1984914..dc39eb7 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -1,6 +1,4 @@ from abc import abstractmethod -import json -import os import random from collections import deque from multiprocessing import Lock, shared_memory diff --git a/pokemonred_puffer/global_map.py b/pokemonred_puffer/global_map.py index dd17a46..a8b9118 100644 --- a/pokemonred_puffer/global_map.py +++ b/pokemonred_puffer/global_map.py @@ -1,18 +1,22 @@ import os import json -MAP_PATH = os.path.join(os.path.dirname(__file__), 'map_data.json') +MAP_PATH = os.path.join(os.path.dirname(__file__), "map_data.json") GLOBAL_MAP_SHAPE = (444, 436) with open(MAP_PATH) as map_data: - MAP_DATA = json.load(map_data)['regions'] -MAP_DATA = {int(e['id']): e for e in MAP_DATA} + MAP_DATA = json.load(map_data)["regions"] +MAP_DATA = {int(e["id"]): e for e in MAP_DATA} + # Handle KeyErrors def local_to_global(r: int, c: int, map_n: int): try: - map_x, map_y,= MAP_DATA[map_n]['coordinates'] + ( + map_x, + map_y, + ) = MAP_DATA[map_n]["coordinates"] return r + map_y, c + map_x except KeyError: - print(f'Map id {map_n} not found in map_data.json.') - return r + 0, c + 0 \ No newline at end of file + print(f"Map id {map_n} not found in map_data.json.") + return r + 0, c + 0 diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index b3c168b..7f3ada4 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -15,10 +15,12 @@ def one_hot(tensor, num_classes): torch.int64 ) + class RecurrentMultiConvolutionalWrapper(pufferlib.models.RecurrentWrapper): def __init__(self, env, policy, input_size=512, hidden_size=512, num_layers=1): super().__init__(env, policy, input_size, hidden_size, num_layers) + class MultiConvolutionalPolicy(pufferlib.models.Policy): def __init__( self, diff --git a/pokemonred_puffer/resnet.py b/pokemonred_puffer/resnet.py index ef8e7b8..cdeeaf2 100644 --- a/pokemonred_puffer/resnet.py +++ b/pokemonred_puffer/resnet.py @@ -166,7 +166,9 @@ def __init__( ) self.groups = groups self.base_width = width_per_group - self.conv1 = nn.Conv2d(input_channels, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) + self.conv1 = nn.Conv2d( + input_channels, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False + ) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) diff --git a/pokemonred_puffer/wrappers/stream_wrapper.py b/pokemonred_puffer/wrappers/stream_wrapper.py index 9590a9e..6fedb87 100644 --- a/pokemonred_puffer/wrappers/stream_wrapper.py +++ b/pokemonred_puffer/wrappers/stream_wrapper.py @@ -58,13 +58,13 @@ async def broadcast_ws_message(self, message): if self.websocket is not None: try: await self.websocket.send(message) - except websockets.exceptions.WebSocketException as e: + except websockets.exceptions.WebSocketException: self.websocket = None async def establish_wc_connection(self): try: self.websocket = await websockets.connect(self.ws_address) - except: + except: # noqa self.websocket = None def reset(self, *args, **kwargs): From 105cef273202b4d4f9928f65648c76da6b7d1872 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Tue, 12 Mar 2024 13:25:10 -0400 Subject: [PATCH 5/6] Add pre-commit instructions --- README.md | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 6aa979d..f7eea26 100644 --- a/README.md +++ b/README.md @@ -74,4 +74,15 @@ To add rewards, add a new class to the `rewards` directory. Then update the `rew ### Adding Policies -To add policies, add a new class to the `policies` directory. Then update the `policies` section of `config.yaml`. A policy section is keyed by the class path. It is assumed that a recurrent policy will live in the same module as the policy it wraps. \ No newline at end of file +To add policies, add a new class to the `policies` directory. Then update the `policies` section of `config.yaml`. A policy section is keyed by the class path. It is assumed that a recurrent policy will live in the same module as the policy it wraps. + +## Development + +This repo uses [pre-commit](https://pre-commit.com/) to enforce formatting and linting. For development, please install this repo with: + +```sh +pip3 install -e '.[dev]' +pre-commit install +``` + +For any changes, please submit a PR. \ No newline at end of file From b5a0c0978c9cb64135f88f5cf27a9d8a1aa41e27 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Tue, 12 Mar 2024 13:26:23 -0400 Subject: [PATCH 6/6] Add github actions --- .github/workflows/workflow.yml | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 .github/workflows/workflow.yml diff --git a/.github/workflows/workflow.yml b/.github/workflows/workflow.yml new file mode 100644 index 0000000..d1f1821 --- /dev/null +++ b/.github/workflows/workflow.yml @@ -0,0 +1,32 @@ +name: tox +on: + push: + pull_request: + +jobs: + test: + name: test ${{ matrix.py }} - ${{ matrix.os }} + runs-on: ${{ matrix.os }}-latest + strategy: + fail-fast: false + matrix: + os: + - Ubuntu + - Windows + - MacOs + py: + - "3.10" + steps: + - name: Setup python for test ${{ matrix.py }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.py }} + - uses: actions/checkout@v3 + - name: Upgrade pip + run: python -m pip install -U pip + - name: Install + run: python -m pip install -e '.[dev]' + - name: Check formatting + run: ruff format pokemonred_puffer + - name: Check lint + run: ruff check pokemonred_puffer \ No newline at end of file