Commit 0495b7f: merge in upstream

kywch committed Mar 12, 2024
2 parents 090bbd0 + b5a0c09
Showing 10 changed files with 75 additions and 16 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/workflow.yml
@@ -0,0 +1,32 @@
+name: tox
+on:
+  push:
+  pull_request:
+
+jobs:
+  test:
+    name: test ${{ matrix.py }} - ${{ matrix.os }}
+    runs-on: ${{ matrix.os }}-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        os:
+          - Ubuntu
+          - Windows
+          - MacOs
+        py:
+          - "3.10"
+    steps:
+      - name: Setup python for test ${{ matrix.py }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.py }}
+      - uses: actions/checkout@v3
+      - name: Upgrade pip
+        run: python -m pip install -U pip
+      - name: Install
+        run: python -m pip install -e '.[dev]'
+      - name: Check formatting
+        run: ruff format pokemonred_puffer
+      - name: Check lint
+        run: ruff check pokemonred_puffer
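
The two ruff steps are the whole gate, so they are easy to reproduce before pushing. A local sketch, assuming a Python 3.10 environment; each command is taken verbatim from the workflow above:

```sh
python -m pip install -U pip
python -m pip install -e '.[dev]'
ruff format pokemonred_puffer   # formatting step from the workflow
ruff check pokemonred_puffer    # lint step from the workflow
```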
9 changes: 9 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,9 @@
+repos:
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.3.2
+    hooks:
+      # Run the linter.
+      - id: ruff
+      # Run the formatter.
+      - id: ruff-format
13 changes: 12 additions & 1 deletion README.md
@@ -74,4 +74,15 @@ To add rewards, add a new class to the `rewards` directory. Then update the `rew
 
 ### Adding Policies
 
-To add policies, add a new class to the `policies` directory. Then update the `policies` section of `config.yaml`. A policy section is keyed by the class path. It is assumed that a recurrent policy will live in the same module as the policy it wraps.
+To add policies, add a new class to the `policies` directory. Then update the `policies` section of `config.yaml`. A policy section is keyed by the class path. It is assumed that a recurrent policy will live in the same module as the policy it wraps.
+
+## Development
+
+This repo uses [pre-commit](https://pre-commit.com/) to enforce formatting and linting. For development, please install this repo with:
+
+```sh
+pip3 install -e '.[dev]'
+pre-commit install
+```
+
+For any changes, please submit a PR.
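
Once `pre-commit install` has registered the hook, ruff runs on every `git commit`. The hooks can also be run by hand with standard pre-commit commands (generic usage, not specific to this repo):

```sh
pre-commit run --all-files   # run ruff and ruff-format across the whole tree
pre-commit autoupdate        # bump the pinned hook revision when desired
```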
8 changes: 3 additions & 5 deletions pokemonred_puffer/cleanrl_puffer.py
@@ -7,7 +7,6 @@
 import uuid
 from collections import defaultdict
 from datetime import timedelta
-import warnings
 
 import numpy as np
 import pufferlib
@@ -145,7 +144,7 @@ def print_dashboard(stats, init_performance, performance):
         elif "time" in k:
             try:
                 v = f"{v:.2f} s"
-            except:
+            except:  # noqa
                 pass
 
         first_word, *rest_words = k.split("_")
@@ -343,7 +342,6 @@ def evaluate(self):
             self.log = False
 
         self.policy_pool.update_policies()
-        performance = defaultdict(list)
         env_profiler = pufferlib.utils.Profiler()
         inference_profiler = pufferlib.utils.Profiler()
         eval_profiler = pufferlib.utils.Profiler(memory=True, pytorch_memory=True).start()
@@ -492,7 +490,7 @@ def evaluate(self):
                     self.stats[f"Histogram/{k}"] = self.wandb.Histogram(v, num_bins=16)
                 self.stats[k] = np.mean(v)
                 self.max_stats[k] = np.max(v)
-            except:
+            except:  # noqa
                 continue
 
         if config.verbose:
@@ -549,7 +547,7 @@ def train(self):
         self.b_obs = b_obs = torch.Tensor(self.obs_ary[b_idxs])
         b_actions = torch.Tensor(self.actions_ary[b_idxs]).to(self.device, non_blocking=True)
         b_logprobs = torch.Tensor(self.logprobs_ary[b_idxs]).to(self.device, non_blocking=True)
-        b_dones = torch.Tensor(self.dones_ary[b_idxs]).to(self.device, non_blocking=True)
+        # b_dones = torch.Tensor(self.dones_ary[b_idxs]).to(self.device, non_blocking=True)
         b_values = torch.Tensor(self.values_ary[b_idxs]).to(self.device, non_blocking=True)
         b_advantages = advantages.reshape(
             config.batch_rows, num_minibatches, config.bptt_horizon
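
The surviving lines in the last hunk all follow one pattern: gather a minibatch from a flat rollout buffer with precomputed indices, then move it to the device with `non_blocking=True`. A toy illustration of that pattern; the buffer sizes and shapes are made up, not the trainer's real config:

```python
import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

obs_ary = np.random.randn(1024, 4).astype(np.float32)  # toy rollout buffer
b_idxs = np.random.permutation(1024).reshape(8, 128)   # 8 minibatches of 128

# Fancy indexing gathers all minibatches at once; the async transfer only
# actually overlaps with compute when the source is in pinned memory.
b_obs = torch.as_tensor(obs_ary[b_idxs]).to(device, non_blocking=True)
print(b_obs.shape)  # torch.Size([8, 128, 4])
```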
16 changes: 10 additions & 6 deletions pokemonred_puffer/global_map.py
@@ -1,18 +1,22 @@
 import os
 import json
 
-MAP_PATH = os.path.join(os.path.dirname(__file__), 'map_data.json')
+MAP_PATH = os.path.join(os.path.dirname(__file__), "map_data.json")
 GLOBAL_MAP_SHAPE = (444, 436)
 
 with open(MAP_PATH) as map_data:
-    MAP_DATA = json.load(map_data)['regions']
-    MAP_DATA = {int(e['id']): e for e in MAP_DATA}
+    MAP_DATA = json.load(map_data)["regions"]
+    MAP_DATA = {int(e["id"]): e for e in MAP_DATA}
 
 
 # Handle KeyErrors
 def local_to_global(r: int, c: int, map_n: int):
     try:
-        map_x, map_y,= MAP_DATA[map_n]['coordinates']
+        (
+            map_x,
+            map_y,
+        ) = MAP_DATA[map_n]["coordinates"]
         return r + map_y, c + map_x
     except KeyError:
-        print(f'Map id {map_n} not found in map_data.json.')
-        return r + 0, c + 0
+        print(f"Map id {map_n} not found in map_data.json.")
+        return r + 0, c + 0
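
A quick usage sketch of the converter above; the map id and local coordinates are illustrative values, not taken from the diff:

```python
from pokemonred_puffer.global_map import GLOBAL_MAP_SHAPE, local_to_global

# Translate a local (row, col) on some map to global-map coordinates.
gr, gc = local_to_global(r=5, c=3, map_n=0)
assert 0 <= gr < GLOBAL_MAP_SHAPE[0] and 0 <= gc < GLOBAL_MAP_SHAPE[1]

# Ids missing from map_data.json (assuming -1 is one) fall back to the
# untranslated coordinates and print a warning instead of raising.
print(local_to_global(r=5, c=3, map_n=-1))  # -> (5, 3)
```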
2 changes: 2 additions & 0 deletions pokemonred_puffer/policies/multi_convolutional.py
@@ -15,10 +15,12 @@ def one_hot(tensor, num_classes):
         torch.int64
     )
 
+
 class RecurrentMultiConvolutionalWrapper(pufferlib.models.RecurrentWrapper):
     def __init__(self, env, policy, input_size=512, hidden_size=512, num_layers=1):
         super().__init__(env, policy, input_size, hidden_size, num_layers)
 
+
 class MultiConvolutionalPolicy(pufferlib.models.Policy):
     def __init__(
         self,
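
The `one_hot` helper whose tail is visible at the top of this hunk appears to wrap the stock PyTorch op; a minimal self-contained sketch of the idea, assuming it mirrors `torch.nn.functional.one_hot` (the `.to(torch.int64)` tail suggests as much):

```python
import torch
import torch.nn.functional as F

def one_hot(tensor, num_classes):
    # Integer ids -> one binary channel per class, kept as int64.
    return F.one_hot(tensor.long(), num_classes).to(torch.int64)

ids = torch.tensor([[1, 2], [0, 3]])
print(one_hot(ids, num_classes=4).shape)  # torch.Size([2, 2, 4])
```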
4 changes: 3 additions & 1 deletion pokemonred_puffer/resnet.py
@@ -166,7 +166,9 @@ def __init__(
         )
         self.groups = groups
         self.base_width = width_per_group
-        self.conv1 = nn.Conv2d(input_channels, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+        self.conv1 = nn.Conv2d(
+            input_channels, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
+        )
         self.bn1 = norm_layer(self.inplanes)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
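
The rewrapped `conv1` is the classic ResNet stem: a 7×7 stride-2 convolution feeding batch norm, ReLU, and a stride-2 max pool, quartering each spatial dimension. A standalone sketch; the channel counts and the 144×160 input size are illustrative assumptions, not values from this repo:

```python
import torch
import torch.nn as nn

stem = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)
out = stem(torch.randn(1, 3, 144, 160))
print(out.shape)  # torch.Size([1, 64, 36, 40])
```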
2 changes: 1 addition & 1 deletion pokemonred_puffer/train.py
@@ -204,7 +204,7 @@ def train(
     "-p",
     "--policy-name",
     default="multi_convolutional.MultiConvolutionalPolicy",
-    help="Policy module to use in policies",
+    help="Policy module to use in policies.",
 )
 parser.add_argument(
     "-r",
4 changes: 2 additions & 2 deletions pokemonred_puffer/wrappers/stream_wrapper.py
@@ -58,13 +58,13 @@ async def broadcast_ws_message(self, message):
         if self.websocket is not None:
             try:
                 await self.websocket.send(message)
-            except websockets.exceptions.WebSocketException as e:
+            except websockets.exceptions.WebSocketException:
                 self.websocket = None
 
     async def establish_wc_connection(self):
         try:
             self.websocket = await websockets.connect(self.ws_address)
-        except:
+        except:  # noqa
             self.websocket = None
 
     def reset(self, *args, **kwargs):
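
Both hunks serve the same defensive pattern: any failure clears `self.websocket`, and the connection is re-established lazily, so a dead stream server never crashes training. A condensed standalone sketch of that pattern; the address and the `state` dict are made up for illustration:

```python
import websockets

async def safe_send(state: dict, message: str, ws_address="ws://localhost:8765"):
    if state.get("ws") is None:
        try:
            state["ws"] = await websockets.connect(ws_address)
        except Exception:  # mirror the wrapper's broad except
            state["ws"] = None
            return
    try:
        await state["ws"].send(message)
    except websockets.exceptions.WebSocketException:
        state["ws"] = None  # drop it; the next call will reconnect
```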
1 change: 1 addition & 0 deletions pyproject.toml
@@ -35,6 +35,7 @@ monitoring = [
     "nvitop"
 ]
 dev = [
+    "pre-commit",
     "ruff"
 ]
 
