Skip to content

Commit

Permalink
scale the step size for fun
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Aug 7, 2024
1 parent 4e51428 commit 5c7ea67
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
3 changes: 2 additions & 1 deletion config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ env:
animate_scripts: False
exploration_inc: 1.0
exploration_max: 1.0
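max_steps_scaling: 0.2 # linear multiplier on max_steps per required event/item gained (intended: 10 gained -> 2x max_steps). NOTE(review): confirm against the reset check in environment.py step(), where the factor appears to apply to the item count only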



Expand Down Expand Up @@ -162,7 +163,7 @@ wrappers:
user: thatguy
- exploration.OnResetExplorationWrapper:
full_reset_frequency: 1
jitter: 1
jitter: 0

fixed_reset_value:
- stream_wrapper.StreamWrapper:
Expand Down
5 changes: 4 additions & 1 deletion pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __init__(self, env_config: pufferlib.namespace):
self.animate_scripts = env_config.animate_scripts
self.exploration_inc = env_config.exploration_inc
self.exploration_max = env_config.exploration_max
self.max_steps_scaling = env_config.max_steps_scaling
self.action_space = ACTION_SPACE

# Obs space-related. TODO: avoid hardcoding?
Expand Down Expand Up @@ -680,7 +681,9 @@ def step(self, action):

self.step_count += 1
reset = (
self.step_count >= self.max_steps # or
self.step_count
>= self.max_steps
* (len(self.required_events) + len(self.required_items) * self.max_steps_scaling) # or
# self.caught_pokemon[6] == 1 # squirtle
)

Expand Down
12 changes: 7 additions & 5 deletions pokemonred_puffer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,18 +204,20 @@ def train(
policy = make_policy(vecenv.driver_env, args.policy_name, args)

args.train.env = "Pokemon Red"
with CleanPuffeRL(
trainer = CleanPuffeRL(
exp_name=args.exp_name,
config=args.train,
vecenv=vecenv,
policy=policy,
env_recv_queues=env_recv_queues,
env_send_queues=env_send_queues,
wandb_client=wandb_client,
) as trainer:
while not trainer.done_training():
trainer.evaluate()
trainer.train()
)
while not trainer.done_training():
trainer.evaluate()
trainer.train()

trainer.close()


if __name__ == "__main__":
Expand Down

0 comments on commit 5c7ea67

Please sign in to comment.