diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index fb43ac5..530ed50 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -3,7 +3,6 @@ import os import random from collections import deque -from multiprocessing import Lock, shared_memory from pathlib import Path from typing import Any, Iterable, Optional import uuid @@ -76,9 +75,6 @@ # TODO: Make global map usage a configuration parameter class RedGymEnv(Env): - env_id = shared_memory.SharedMemory(create=True, size=4) - lock = Lock() - def __init__(self, env_config: pufferlib.namespace): # TODO: Dont use pufferlib.namespace. It seems to confuse __init__ self.video_dir = Path(env_config.video_dir) @@ -194,19 +190,6 @@ def __init__(self, env_config: pufferlib.namespace): self.screen = self.pyboy.screen self.first = True - with RedGymEnv.lock: - env_id = ( - (int(RedGymEnv.env_id.buf[0]) << 24) - + (int(RedGymEnv.env_id.buf[1]) << 16) - + (int(RedGymEnv.env_id.buf[2]) << 8) - + (int(RedGymEnv.env_id.buf[3])) - ) - self.env_id = env_id - env_id += 1 - RedGymEnv.env_id.buf[0] = (env_id >> 24) & 0xFF - RedGymEnv.env_id.buf[1] = (env_id >> 16) & 0xFF - RedGymEnv.env_id.buf[2] = (env_id >> 8) & 0xFF - RedGymEnv.env_id.buf[3] = (env_id) & 0xFF self.init_mem() diff --git a/pokemonred_puffer/wrappers/stream_wrapper.py b/pokemonred_puffer/wrappers/stream_wrapper.py index dba7872..9fc2d8c 100644 --- a/pokemonred_puffer/wrappers/stream_wrapper.py +++ b/pokemonred_puffer/wrappers/stream_wrapper.py @@ -1,5 +1,6 @@ import asyncio import json +from multiprocessing import Lock, shared_memory import gymnasium as gym import websockets @@ -9,8 +10,25 @@ class StreamWrapper(gym.Wrapper): + env_id = shared_memory.SharedMemory(create=True, size=4) + lock = Lock() + def __init__(self, env: RedGymEnv, config: pufferlib.namespace): super().__init__(env) + with RedGymEnv.lock: + env_id = ( + (int(RedGymEnv.env_id.buf[0]) << 24) + + (int(RedGymEnv.env_id.buf[1]) << 16) + + (int(RedGymEnv.env_id.buf[2]) << 8) + + (int(RedGymEnv.env_id.buf[3])) + ) + self.env_id = env_id + env_id += 1 + RedGymEnv.env_id.buf[0] = (env_id >> 24) & 0xFF + RedGymEnv.env_id.buf[1] = (env_id >> 16) & 0xFF + RedGymEnv.env_id.buf[2] = (env_id >> 8) & 0xFF + RedGymEnv.env_id.buf[3] = (env_id) & 0xFF + self.user = config.user self.ws_address = "wss://transdimensional.xyz/broadcast" self.stream_metadata = {