
[Question] Continued Training in Stable Baseline 3 #597

Closed
2 tasks done
Necropsy opened this issue Oct 6, 2021 · 3 comments · Fixed by #615
Assignees: araffin
Labels: bug (Something isn't working), question (Further information is requested)

Comments

@Necropsy

Necropsy commented Oct 6, 2021

Question

The agent does not appear to learn over time when following a continued-training scheme. Is the problem with this specific training setup or with SB3?

More details

I have a custom env that connects to a network simulator. The length of each simulation is fixed at 500 timesteps; after that many steps I have to stop training, save the model so the simulator can be restarted, and then resume training. Using DQN in my env I noticed that the agent was not learning over time, so I tested the same training scheme in a Gym environment to validate the idea of continued training, but I ran into the same problem.

Code

Continued training:

import gym
import numpy as np

import os
import random as rd

from stable_baselines3 import DQN
from stable_baselines3.common.monitor import Monitor

def create_env():
    env = gym.make("CartPole-v0")
    return env

# Random seeds for the continued-training runs
seeds = rd.sample(range(1, 5001), 51)

log_dir = "logs/"
os.makedirs(log_dir, exist_ok=True)

env = create_env()
env = Monitor(env, log_dir+"_test_0")
obs = env.reset()

model = DQN("MlpPolicy", env,
            learning_rate = 1e-3,
            buffer_size = 1000000,
            learning_starts = 500,
            batch_size = 64,
            tau = 1.0,
            gamma = 0.99,
            train_freq = 4,
            gradient_steps = 1,
            optimize_memory_usage = False,
            target_update_interval = 250,
            exploration_fraction = 0.1,
            exploration_initial_eps = 1.0,
            exploration_final_eps = 0.05,
            max_grad_norm = 10,
            policy_kwargs=dict(net_arch=[32, 32]),
            device ='cpu',
            verbose=1
            )

# CartPole-v0 episodes are capped at 200 steps
model.learn(total_timesteps=200, log_interval=4)

model.save('./logs/dqn_save')
model.save_replay_buffer('./logs/dqn_save_replay_buffer')

del env
del model

# Load and Train
for i in range(50):
    env = create_env()
    env = Monitor(env, log_dir+"_test_"+str(i+1))
    
    model = DQN.load("logs/dqn_save")
    model.load_replay_buffer('logs/dqn_save_replay_buffer')
    
    model.set_env(env)
    model.set_random_seed(seeds[i])
    
    model.learn(total_timesteps=200, log_interval=4)

    model.save('logs/dqn_save')
    model.save_replay_buffer('logs/dqn_save_replay_buffer')
    
    del model
    del env

Non-continued training:

log_dir = "logs/"
os.makedirs(log_dir, exist_ok=True)

env = create_env()
env = Monitor(env, log_dir+"_test_0")
obs = env.reset()

model = DQN("MlpPolicy", env,
            learning_rate = 1e-3,
            buffer_size = 1000000,
            learning_starts = 500,
            batch_size = 64,
            tau = 1.0,
            gamma = 0.99,
            train_freq = 4,
            gradient_steps = 1,
            optimize_memory_usage = False,
            target_update_interval = 250,
            exploration_fraction = 0.1,
            exploration_initial_eps = 1.0,
            exploration_final_eps = 0.05,
            max_grad_norm = 10,
            policy_kwargs=dict(net_arch=[32, 32]),
            device ='cpu',
            verbose=1
            )

model.learn(total_timesteps=10000, log_interval=4)

model.save('./logs/dqn_save')
model.save_replay_buffer('./logs/dqn_save_replay_buffer')

del env
del model

(Attached reward plots: ag1_50ep200st, ag1_10000st)

  • I also tried load_parameters and get_parameters, but to no avail (the results did not change)


Necropsy added the question label on Oct 6, 2021
@araffin
Member

araffin commented Oct 6, 2021

Hello,
First of all, this type of training is very special...

The learning fails in your case for a very specific reason:
learning_starts and target_update_interval are both above the total_timesteps budget of each individual call to .learn(), so your agent just behaves randomly and never updates the target network (and because reset_num_timesteps=True by default, the timestep counter is reset on every call).

For this special case, a simple fix is to replace the condition from the linked code snippet

by if False (so it does not reset the counter but still resets the environment).
We may need a fix in SB3, but it seems to be a very specific case though...

EDIT: there is probably a simpler fix:

model._last_obs = None
model.learn(total_timesteps=200, log_interval=4, reset_num_timesteps=False)

I think we can do that internally in SB3 when the user calls set_env()
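
Applied to the load-and-train loop from the original post, the suggested workaround would look roughly like this (a sketch reusing the paths, the seeds list and the create_env() helper from the snippet above; note that _last_obs is a private attribute, so this detail may change across SB3 versions):

for i in range(50):
    env = create_env()
    env = Monitor(env, log_dir + "_test_" + str(i + 1))

    model = DQN.load("logs/dqn_save")
    model.load_replay_buffer("logs/dqn_save_replay_buffer")

    model.set_env(env)
    model.set_random_seed(seeds[i])

    # Force a fresh env reset while keeping the global timestep counter,
    # so that learning_starts and target_update_interval are eventually reached
    model._last_obs = None
    model.learn(total_timesteps=200, log_interval=4, reset_num_timesteps=False)

    model.save("logs/dqn_save")
    model.save_replay_buffer("logs/dqn_save_replay_buffer")

    del model
    del env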

araffin added the bug label on Oct 6, 2021
@araffin araffin self-assigned this Oct 16, 2021
@araffin
Member

araffin commented Oct 18, 2021

An additional note: for DQN, exploration_fraction = 0.1 does not give the same schedule in both cases, as the exploration schedule is based on total_timesteps.

EDIT: and setting the seed periodically is also different from training over one long run, because here the number of steps per episode varies (there is a max limit, but the episode length is not constant)
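
To make the schedule point concrete, here is a rough, illustrative sketch of a linear epsilon schedule (not SB3's actual implementation): the decay horizon is exploration_fraction * total_timesteps of the current learn() call, so the two setups explore very differently.

def epsilon(step, total_timesteps, fraction=0.1, eps_start=1.0, eps_end=0.05):
    # Linear decay from eps_start to eps_end over fraction * total_timesteps steps
    progress = min(step / (fraction * total_timesteps), 1.0)
    return eps_start + progress * (eps_end - eps_start)

# Continued training: each learn() call uses total_timesteps=200,
# so epsilon reaches its final value after only 20 steps of every call
print(epsilon(20, total_timesteps=200))    # 0.05
# Single long run: with total_timesteps=10000 the decay spans 1000 steps
print(epsilon(20, total_timesteps=10000))  # ~0.98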

@Necropsy
Author

Thanks @araffin,

Sorry for the delay in replying, I've been working on other projects for the past few days.

I tested the following change as recommended:

model._last_obs = None
model.learn(total_timesteps=200, log_interval=4, reset_num_timesteps=False)

It works (tested on the CartPole-v0 environment). In my own scenario I haven't been successful yet (the agent still behaves randomly during continued training); my environment differs from CartPole-v0 in that each episode has a fixed length of 200 steps, i.e. the environment only returns done after exactly 200 steps.

I will continue exploring!
