Skip to content

Commit

Permalink
Remove seeding on reset
Browse files: browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Apr 25, 2024
1 parent 4a3506c commit 095d8fc
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 19 deletions.
14 changes: 7 additions & 7 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@ wandb:

debug:
env:
headless: True
headless: False
stream_wrapper: False
init_state: cut
max_steps: 4
max_steps: 1_000_000
train:
device: cpu
compile: False
compile_mode: default
num_envs: 10
num_envs: 4
envs_per_worker: 1
envs_per_batch: 1
batch_size: 8
envs_per_batch: 4
batch_size: 16
batch_rows: 4
bptt_horizon: 2
total_timesteps: 100_000_000
Expand All @@ -28,8 +28,8 @@ debug:
env_pool: False
log_frequency: 5000
load_optimizer_state: False
swarm_frequency: 5
swarm_keep_pct: .8
swarm_frequency: 10
swarm_keep_pct: .1

env:
headless: True
Expand Down
6 changes: 3 additions & 3 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,9 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] =
self.moves_obtained = np.zeros(0xA5, dtype=np.uint8)
self.pokecenters = np.zeros(252, dtype=np.uint8)
# lazy random seed setting
if not seed:
seed = random.randint(0, 4096)
self.pyboy.tick(seed, render=False)
# if not seed:
# seed = random.randint(0, 4096)
# self.pyboy.tick(seed, render=False)
else:
self.reset_count += 1

Expand Down
11 changes: 2 additions & 9 deletions pokemonred_puffer/rewards/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
class BaselineRewardEnv(RedGymEnv):
def __init__(self, env_config: pufferlib.namespace, reward_config: pufferlib.namespace):
super().__init__(env_config)
self.reward_config = reward_config

# TODO: make the reward weights configurable
def get_game_state_reward(self):
Expand Down Expand Up @@ -83,10 +84,6 @@ def get_levels_reward(self):


class TeachCutReplicationEnv(BaselineRewardEnv):
def __init__(self, env_config: pufferlib.namespace, reward_config: pufferlib.namespace):
super().__init__(env_config)
self.reward_config = reward_config

def get_game_state_reward(self):
return {
"event": self.reward_config["event"] * self.update_max_event_rew(),
Expand Down Expand Up @@ -117,10 +114,6 @@ def get_game_state_reward(self):


class TeachCutReplicationEnvFork(BaselineRewardEnv):
def __init__(self, env_config: pufferlib.namespace, reward_config: pufferlib.namespace):
super().__init__(env_config)
self.reward_config = reward_config

def get_game_state_reward(self):
return {
"event": self.reward_config["event"] * self.update_max_event_rew(),
Expand Down Expand Up @@ -172,7 +165,7 @@ def get_levels_reward(self):
return 15 + (self.max_level_sum - 15) / 4


class RockTunnelReplicationEnv(TeachCutReplicationEnv):
class RockTunnelReplicationEnv(BaselineRewardEnv):
def get_game_state_reward(self):
return {
"level": self.reward_config["level"] * self.get_levels_reward(),
Expand Down

0 comments on commit 095d8fc

Please sign in to comment.