diff --git a/.github/workflows/workflow.yml b/.github/workflows/workflow.yml
new file mode 100644
index 0000000..d1f1821
--- /dev/null
+++ b/.github/workflows/workflow.yml
@@ -0,0 +1,32 @@
+name: tox
+on:
+  push:
+  pull_request:
+
+jobs:
+  test:
+    name: test ${{ matrix.py }} - ${{ matrix.os }}
+    runs-on: ${{ matrix.os }}-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        os:
+          - Ubuntu
+          - Windows
+          - MacOs
+        py:
+          - "3.10"
+    steps:
+      - name: Setup python for test ${{ matrix.py }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.py }}
+      - uses: actions/checkout@v3
+      - name: Upgrade pip
+        run: python -m pip install -U pip
+      - name: Install
+        run: python -m pip install -e '.[dev]'
+      - name: Check formatting
+        run: ruff format --check pokemonred_puffer
+      - name: Check lint
+        run: ruff check pokemonred_puffer
\ No newline at end of file
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..8f03df7
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,9 @@
+repos:
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.3.2
+    hooks:
+      # Run the linter.
+      - id: ruff
+      # Run the formatter.
+      - id: ruff-format
\ No newline at end of file
diff --git a/README.md b/README.md
index 6aa979d..f7eea26 100644
--- a/README.md
+++ b/README.md
@@ -74,4 +74,15 @@ To add rewards, add a new class to the `rewards` directory. Then update the `rew
 
 ### Adding Policies
 
-To add policies, add a new class to the `policies` directory. Then update the `policies` section of `config.yaml`. A policy section is keyed by the class path. It is assumed that a recurrent policy will live in the same module as the policy it wraps.
\ No newline at end of file
+To add policies, add a new class to the `policies` directory. Then update the `policies` section of `config.yaml`. A policy section is keyed by the class path. It is assumed that a recurrent policy will live in the same module as the policy it wraps.
+
+## Development
+
+This repo uses [pre-commit](https://pre-commit.com/) to enforce formatting and linting. For development, please install this repo with:
+
+```sh
+pip3 install -e '.[dev]'
+pre-commit install
+```
+
+For any changes, please submit a PR.
\ No newline at end of file
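
For context on the hooks configured above: once `pre-commit install` has run, ruff re-checks the staged files on every `git commit`. The same checks can also be run by hand over the whole tree, which is roughly what the CI job does. A minimal sketch, assuming the commands are run from the repo root:

```sh
# Run every configured hook against all files, not just staged ones.
pre-commit run --all-files

# Or invoke ruff directly, mirroring the two CI steps.
ruff format --check pokemonred_puffer
ruff check pokemonred_puffer
```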
diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py
index 754677d..60e7b86 100644
--- a/pokemonred_puffer/cleanrl_puffer.py
+++ b/pokemonred_puffer/cleanrl_puffer.py
@@ -7,7 +7,6 @@ import uuid
 from collections import defaultdict
 from datetime import timedelta
-import warnings
 
 import numpy as np
 import pufferlib
@@ -145,7 +144,7 @@ def print_dashboard(stats, init_performance, performance):
         elif "time" in k:
             try:
                 v = f"{v:.2f} s"
-            except:
+            except:  # noqa
                 pass
 
         first_word, *rest_words = k.split("_")
@@ -343,7 +342,6 @@ def evaluate(self):
         self.log = False
 
         self.policy_pool.update_policies()
-        performance = defaultdict(list)
         env_profiler = pufferlib.utils.Profiler()
         inference_profiler = pufferlib.utils.Profiler()
         eval_profiler = pufferlib.utils.Profiler(memory=True, pytorch_memory=True).start()
@@ -492,7 +490,7 @@ def evaluate(self):
                     self.stats[f"Histogram/{k}"] = self.wandb.Histogram(v, num_bins=16)
                 self.stats[k] = np.mean(v)
                 self.max_stats[k] = np.max(v)
-            except:
+            except:  # noqa
                 continue
 
         if config.verbose:
@@ -549,7 +547,7 @@ def train(self):
         self.b_obs = b_obs = torch.Tensor(self.obs_ary[b_idxs])
         b_actions = torch.Tensor(self.actions_ary[b_idxs]).to(self.device, non_blocking=True)
         b_logprobs = torch.Tensor(self.logprobs_ary[b_idxs]).to(self.device, non_blocking=True)
-        b_dones = torch.Tensor(self.dones_ary[b_idxs]).to(self.device, non_blocking=True)
+        # b_dones = torch.Tensor(self.dones_ary[b_idxs]).to(self.device, non_blocking=True)
         b_values = torch.Tensor(self.values_ary[b_idxs]).to(self.device, non_blocking=True)
         b_advantages = advantages.reshape(
             config.batch_rows, num_minibatches, config.bptt_horizon
diff --git a/pokemonred_puffer/global_map.py b/pokemonred_puffer/global_map.py
index dd17a46..a8b9118 100644
--- a/pokemonred_puffer/global_map.py
+++ b/pokemonred_puffer/global_map.py
@@ -1,18 +1,22 @@
 import os
 import json
 
-MAP_PATH = os.path.join(os.path.dirname(__file__), 'map_data.json')
+MAP_PATH = os.path.join(os.path.dirname(__file__), "map_data.json")
 GLOBAL_MAP_SHAPE = (444, 436)
 
 with open(MAP_PATH) as map_data:
-    MAP_DATA = json.load(map_data)['regions']
-MAP_DATA = {int(e['id']): e for e in MAP_DATA}
+    MAP_DATA = json.load(map_data)["regions"]
+MAP_DATA = {int(e["id"]): e for e in MAP_DATA}
+
 
 # Handle KeyErrors
 def local_to_global(r: int, c: int, map_n: int):
     try:
-        map_x, map_y,= MAP_DATA[map_n]['coordinates']
+        (
+            map_x,
+            map_y,
+        ) = MAP_DATA[map_n]["coordinates"]
         return r + map_y, c + map_x
     except KeyError:
-        print(f'Map id {map_n} not found in map_data.json.')
-        return r + 0, c + 0
\ No newline at end of file
+        print(f"Map id {map_n} not found in map_data.json.")
+        return r + 0, c + 0
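
For readers of the `global_map.py` hunk just above, a minimal usage sketch of the patched helper (the map id and local coordinates are illustrative values, not taken from a real rollout):

```python
from pokemonred_puffer.global_map import GLOBAL_MAP_SHAPE, local_to_global

# Translate a local (row, col) on map id 0 into coordinates on the
# stitched global map. Unknown map ids print a warning and fall back
# to returning (r, c) unchanged.
gr, gc = local_to_global(r=5, c=3, map_n=0)
print(gr, gc, GLOBAL_MAP_SHAPE)  # global position and the (444, 436) canvas
```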
diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py
index b3c168b..7f3ada4 100644
--- a/pokemonred_puffer/policies/multi_convolutional.py
+++ b/pokemonred_puffer/policies/multi_convolutional.py
@@ -15,10 +15,12 @@ def one_hot(tensor, num_classes):
         torch.int64
     )
 
+
 class RecurrentMultiConvolutionalWrapper(pufferlib.models.RecurrentWrapper):
     def __init__(self, env, policy, input_size=512, hidden_size=512, num_layers=1):
         super().__init__(env, policy, input_size, hidden_size, num_layers)
 
+
 class MultiConvolutionalPolicy(pufferlib.models.Policy):
     def __init__(
         self,
diff --git a/pokemonred_puffer/resnet.py b/pokemonred_puffer/resnet.py
index ef8e7b8..cdeeaf2 100644
--- a/pokemonred_puffer/resnet.py
+++ b/pokemonred_puffer/resnet.py
@@ -166,7 +166,9 @@ def __init__(
         )
         self.groups = groups
         self.base_width = width_per_group
-        self.conv1 = nn.Conv2d(input_channels, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+        self.conv1 = nn.Conv2d(
+            input_channels, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
+        )
         self.bn1 = norm_layer(self.inplanes)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
diff --git a/pokemonred_puffer/train.py b/pokemonred_puffer/train.py
index 47f4095..ec1d786 100644
--- a/pokemonred_puffer/train.py
+++ b/pokemonred_puffer/train.py
@@ -204,7 +204,7 @@ def train(
         "-p",
         "--policy-name",
         default="multi_convolutional.MultiConvolutionalPolicy",
-        help="Policy module to use in policies",
+        help="Policy module to use in policies.",
     )
     parser.add_argument(
         "-r",
diff --git a/pokemonred_puffer/wrappers/stream_wrapper.py b/pokemonred_puffer/wrappers/stream_wrapper.py
index 9590a9e..6fedb87 100644
--- a/pokemonred_puffer/wrappers/stream_wrapper.py
+++ b/pokemonred_puffer/wrappers/stream_wrapper.py
@@ -58,13 +58,13 @@ async def broadcast_ws_message(self, message):
         if self.websocket is not None:
             try:
                 await self.websocket.send(message)
-            except websockets.exceptions.WebSocketException as e:
+            except websockets.exceptions.WebSocketException:
                 self.websocket = None
 
     async def establish_wc_connection(self):
         try:
             self.websocket = await websockets.connect(self.ws_address)
-        except:
+        except:  # noqa
             self.websocket = None
 
     def reset(self, *args, **kwargs):
diff --git a/pyproject.toml b/pyproject.toml
index 0c971ac..c57c8e5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,6 +35,7 @@ monitoring = [
     "nvitop"
 ]
 dev = [
+    "pre-commit",
     "ruff"
 ]
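
One note on the `conv1` reformat in `pokemonred_puffer/resnet.py` above: the wrapped call is the standard ResNet stem, a 7x7 stride-2 convolution with padding 3 that halves spatial resolution. A standalone shape check under assumed values (3 input channels, the stock ResNet `inplanes` of 64):

```python
import torch
import torch.nn as nn

# H_out = floor((H + 2*padding - kernel) / stride) + 1 = floor((224 + 6 - 7) / 2) + 1 = 112
conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
x = torch.zeros(1, 3, 224, 224)
print(conv1(x).shape)  # torch.Size([1, 64, 112, 112])
```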