Skip to content

Commit

Permalink
Move env id calculation to the stream wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jun 25, 2024
1 parent 467df24 commit 8eb0d5c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 17 deletions.
17 changes: 0 additions & 17 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os
import random
from collections import deque
from multiprocessing import Lock, shared_memory
from pathlib import Path
from typing import Any, Iterable, Optional
import uuid
Expand Down Expand Up @@ -76,9 +75,6 @@

# TODO: Make global map usage a configuration parameter
class RedGymEnv(Env):
env_id = shared_memory.SharedMemory(create=True, size=4)
lock = Lock()

def __init__(self, env_config: pufferlib.namespace):
# TODO: Dont use pufferlib.namespace. It seems to confuse __init__
self.video_dir = Path(env_config.video_dir)
Expand Down Expand Up @@ -194,19 +190,6 @@ def __init__(self, env_config: pufferlib.namespace):
self.screen = self.pyboy.screen

self.first = True
with RedGymEnv.lock:
env_id = (
(int(RedGymEnv.env_id.buf[0]) << 24)
+ (int(RedGymEnv.env_id.buf[1]) << 16)
+ (int(RedGymEnv.env_id.buf[2]) << 8)
+ (int(RedGymEnv.env_id.buf[3]))
)
self.env_id = env_id
env_id += 1
RedGymEnv.env_id.buf[0] = (env_id >> 24) & 0xFF
RedGymEnv.env_id.buf[1] = (env_id >> 16) & 0xFF
RedGymEnv.env_id.buf[2] = (env_id >> 8) & 0xFF
RedGymEnv.env_id.buf[3] = (env_id) & 0xFF

self.init_mem()

Expand Down
18 changes: 18 additions & 0 deletions pokemonred_puffer/wrappers/stream_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import json
from multiprocessing import Lock, shared_memory

import gymnasium as gym
import websockets
Expand All @@ -9,8 +10,25 @@


class StreamWrapper(gym.Wrapper):
env_id = shared_memory.SharedMemory(create=True, size=4)
lock = Lock()

def __init__(self, env: RedGymEnv, config: pufferlib.namespace):
super().__init__(env)
with RedGymEnv.lock:
env_id = (
(int(RedGymEnv.env_id.buf[0]) << 24)
+ (int(RedGymEnv.env_id.buf[1]) << 16)
+ (int(RedGymEnv.env_id.buf[2]) << 8)
+ (int(RedGymEnv.env_id.buf[3]))
)
self.env_id = env_id
env_id += 1
RedGymEnv.env_id.buf[0] = (env_id >> 24) & 0xFF
RedGymEnv.env_id.buf[1] = (env_id >> 16) & 0xFF
RedGymEnv.env_id.buf[2] = (env_id >> 8) & 0xFF
RedGymEnv.env_id.buf[3] = (env_id) & 0xFF

self.user = config.user
self.ws_address = "wss://transdimensional.xyz/broadcast"
self.stream_metadata = {
Expand Down

0 comments on commit 8eb0d5c

Please sign in to comment.