Model Training #20

Closed
wants to merge 8 commits into from
273 changes: 273 additions & 0 deletions src/train.py
@@ -0,0 +1,273 @@
'''
## Train ##
# Code to train Deep Q Network on OpenAI Gym environments
@author: Mark Sinton ([email protected])
'''

import os
import sys
import argparse
import gym
import tensorflow as tf
import numpy as np
import time
import random

from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT

from utils.utils import preprocess_image, reset_env_and_state_buffer
from utils.experience_replay import ReplayMemory
from utils.state_buffer import StateBuffer
from utils.network import DeepQNetwork

def get_train_args(args=None):
    train_params = argparse.ArgumentParser()

    # Environment parameters
    train_params.add_argument("--env", type=str, default='SuperMarioBros-1-1-v0', help="Environment to use (must have RGB image state space and discrete action space)")
    train_params.add_argument("--render", action='store_true', help="Whether or not to display the environment on the screen during training")
    train_params.add_argument("--random_seed", type=int, default=1234, help="Random seed for reproducibility")
    train_params.add_argument("--frame_width", type=int, default=105, help="Frame width after resize.")
    train_params.add_argument("--frame_height", type=int, default=80, help="Frame height after resize.")
    train_params.add_argument("--frames_per_state", type=int, default=4, help="Sequence of frames which constitutes a single state.")

    # Training parameters
    train_params.add_argument("--num_steps_train", type=int, default=50000000, help="Number of steps to train for")
    train_params.add_argument("--train_frequency", type=int, default=4, help="Perform training step every N game steps.")
    train_params.add_argument("--max_ep_steps", type=int, default=2000, help="Maximum number of steps per episode")
    train_params.add_argument("--batch_size", type=int, default=32)
    train_params.add_argument("--learning_rate", type=float, default=0.00025)
    train_params.add_argument("--replay_mem_size", type=int, default=1000000, help="Maximum size of replay memory buffer")
    train_params.add_argument("--initial_replay_mem_size", type=int, default=50000, help="Initial size of replay memory (populated by random actions) before learning can start")
    train_params.add_argument("--epsilon_start", type=float, default=1.0, help="Exploration rate at the beginning of training.")
    train_params.add_argument("--epsilon_end", type=float, default=0.1, help="Exploration rate at the end of decay.")
    train_params.add_argument("--epsilon_step_end", type=int, default=1000000, help="After how many steps to stop decaying the exploration rate.")
    train_params.add_argument("--discount_rate", type=float, default=0.99, help="Discount rate (gamma) for future rewards.")
    train_params.add_argument("--update_target_step", type=int, default=10000, help="Copy current network parameters to target network every N steps.")
    train_params.add_argument("--save_ckpt_step", type=int, default=250000, help="Save checkpoint every N steps")
    train_params.add_argument("--save_log_step", type=int, default=1000, help="Save logs every N steps")

    # Files/directories
    train_params.add_argument("--ckpt_dir", type=str, default='./ckpts', help="Directory for saving/loading checkpoints")
    train_params.add_argument("--ckpt_file", type=str, default=None, help="Checkpoint file to load and resume training from (if None, train from scratch)")
    train_params.add_argument("--log_dir", type=str, default='./logs/train', help="Directory for saving logs")

    return train_params.parse_args(args)


def train(args):

    # Function to return exploration rate based on current step
    def exploration_rate(current_step, exp_rate_start, exp_rate_end, exp_step_end):
        # Linearly anneal from exp_rate_start to exp_rate_end over exp_step_end steps, then hold
        if current_step < exp_step_end:
            exploration_rate = exp_rate_start + current_step * ((exp_rate_end - exp_rate_start) / float(exp_step_end))
        else:
            exploration_rate = exp_rate_end

        return exploration_rate

    # Function to update target network parameters with main network parameters
    def update_target_network(from_scope, to_scope):
        from_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, from_scope)
        to_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, to_scope)

        op_holder = []

        # Update old network parameters with new network parameters
        for from_var, to_var in zip(from_vars, to_vars):
            op_holder.append(to_var.assign(from_var))

        return op_holder
Comment on lines +73 to +82
Collaborator

I'm assuming this will need to be compatible with @rithvikb's model code, but this is fine as a placeholder.



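If the model code ends up being a tf.keras.Model rather than networks built under variable scopes, the same target-network sync could be done without tf.get_collection at all. A minimal sketch under that assumption (main_net/target_net are hypothetical Keras-style model objects, not names from this PR):

```python
def update_target_network_keras(main_net, target_net):
    # Copy the main network's weights into the target network in one call
    target_net.set_weights(main_net.get_weights())

# Usage (hypothetical): call every args.update_target_step training steps
# update_target_network_keras(DQN_main_net, DQN_target_net)
```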
    # Create environment
    env = gym_super_mario_bros.make(args.env)
    env = JoypadSpace(env, SIMPLE_MOVEMENT)
    num_actions = env.action_space.n

    # Initialise replay memory and state buffer
    replay_mem = ReplayMemory(args)
    state_buf = StateBuffer(args)

    # Define input placeholders
    state_ph = tf.placeholder(tf.uint8, (None, args.frame_height, args.frame_width, args.frames_per_state))
    action_ph = tf.placeholder(tf.int32, (None))
    target_ph = tf.placeholder(tf.float32, (None))
Comment on lines +95 to +97
Collaborator

Might not be tf.placeholders but again, not clear until @rithvikb creates a PR

Collaborator

Again, fine as a placeholder though for this PR

Collaborator

Applies in multiple places below but that is fine


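For what it's worth, if the model PR lands on TF2/eager rather than graph mode, these inputs would not need placeholders at all; batched NumPy arrays could be passed straight into a Keras-style network. A rough sketch under that assumption (layer sizes and names are illustrative only, not the agreed interface):

```python
import tensorflow as tf

def build_q_network(num_actions, frame_height=80, frame_width=105, frames_per_state=4):
    # Hypothetical TF2/Keras Q-network; architecture is a placeholder, not the model PR
    return tf.keras.Sequential([
        tf.keras.layers.Lambda(lambda x: tf.cast(x, tf.float32) / 255.0,
                               input_shape=(frame_height, frame_width, frames_per_state)),
        tf.keras.layers.Conv2D(32, 8, strides=4, activation='relu'),
        tf.keras.layers.Conv2D(64, 4, strides=2, activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.Dense(num_actions),
    ])

# Greedy action selection then needs no session or feed dict:
# q_values = q_net(np.expand_dims(state_buf.get_state(), 0))
# action = int(tf.argmax(q_values, axis=1)[0])
```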
    # Instantiate DQN network
    # Note: One scope cannot be the prefix of another scope (e.g. cannot name this scope 'DQN' and the
    # target network scope 'DQN_target', as a search for vars in the 'DQN' scope would return both networks' vars)
    DQN = DeepQNetwork(num_actions, state_ph, action_ph, target_ph, args.learning_rate, scope='DQN_main')
    DQN_predict_op = DQN.predict()
    DQN_train_step_op = DQN.train_step()

    # Instantiate DQN target network
    DQN_target = DeepQNetwork(num_actions, state_ph, scope='DQN_target')

    update_target_op = update_target_network('DQN_main', 'DQN_target')

    # Create session
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # Add summaries for Tensorboard visualisation
    tf.summary.scalar('Loss', DQN.loss)
    reward_var = tf.Variable(0.0, trainable=False)
    tf.summary.scalar("Episode_Reward", reward_var)
    epsilon_var = tf.Variable(args.epsilon_start, trainable=False)
    tf.summary.scalar("Exploration_Rate", epsilon_var)
    summary_op = tf.summary.merge_all()

    # Define saver for saving model ckpts
    model_name = 'model.ckpt'
    checkpoint_path = os.path.join(args.ckpt_dir, model_name)
    if not os.path.exists(args.ckpt_dir):
        os.makedirs(args.ckpt_dir)
    saver = tf.train.Saver(max_to_keep=201)

    # Create summary writer to write summaries to disk
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    summary_writer = tf.summary.FileWriter(args.log_dir, sess.graph)

    # Load ckpt file if given
    if args.ckpt_file is not None:
        loader = tf.train.Saver()   # Restore all variables from ckpt
        ckpt = args.ckpt_dir + '/' + args.ckpt_file
        ckpt_split = ckpt.split('-')
        step_str = ckpt_split[-1]
        start_step = int(step_str)
        loader.restore(sess, ckpt)
    else:
        start_step = 0
        sess.run(tf.global_variables_initializer())
        sess.run(update_target_op)


    ## Begin training

    env.reset()

    ep_steps = 0
    episode_reward = 0
    episode_rewards = []
    duration_values = []

    # Initially populate replay memory by taking random actions
    sys.stdout.write('\nPopulating replay memory with random actions...\n')
    sys.stdout.flush()

    for random_step in range(1, args.initial_replay_mem_size+1):

        if args.render:
            env.render()
        else:
            env.render(mode='rgb_array')

        action = env.action_space.sample()
        frame, reward, terminal, _ = env.step(action)
        frame = preprocess_image(frame, args.frame_width, args.frame_height)
        replay_mem.add(action, reward, frame, terminal)

        if terminal:
            env.reset()

        sys.stdout.write('\x1b[2K\rStep {:d}/{:d}'.format(random_step, args.initial_replay_mem_size))
        sys.stdout.flush()

    # Begin training process
    reset_env_and_state_buffer(env, state_buf, args)
    sys.stdout.write('\n\nTraining...\n\n')
    sys.stdout.flush()

    for train_step in range(start_step+1, args.num_steps_train+1):
        start_time = time.time()
        # Run 'train_frequency' iterations in the game for every training step
        for _ in range(0, args.train_frequency):
            ep_steps += 1

            if args.render:
                env.render()
            else:
                env.render(mode='rgb_array')

            # Use an epsilon-greedy policy to select action
            epsilon = exploration_rate(train_step, args.epsilon_start, args.epsilon_end, args.epsilon_step_end)
            if random.random() < epsilon:
                # Choose random action
                action = env.action_space.sample()
            else:
                # Choose action with highest Q-value according to network's current policy
                current_state = np.expand_dims(state_buf.get_state(), 0)
                action = sess.run(DQN_predict_op, {state_ph: current_state})[0]

            # Take action and store experience
            frame, reward, terminal, _ = env.step(action)

            frame = preprocess_image(frame, args.frame_width, args.frame_height)
            state_buf.add(frame)
            replay_mem.add(action, reward, frame, terminal)
            episode_reward += reward

            if terminal or ep_steps == args.max_ep_steps:
                # Collect total reward of episode
                episode_rewards.append(episode_reward)
                # Reset episode reward and episode steps counters
                episode_reward = 0
                ep_steps = 0
                # Reset environment and state buffer for next episode
                reset_env_and_state_buffer(env, state_buf, args)

        ## Training step
        # Get minibatch from replay mem
        states_batch, actions_batch, rewards_batch, next_states_batch, terminals_batch = replay_mem.getMinibatch()
        # Calculate target by passing next states through the target network and finding max future Q
        future_Q = sess.run(DQN_target.output, {state_ph: next_states_batch})
        max_future_Q = np.max(future_Q, axis=1)
        # Q values of the terminal states are 0 by definition
        max_future_Q[terminals_batch] = 0
        targets = rewards_batch + (max_future_Q*args.discount_rate)

        # Execute training step
        if train_step % args.save_log_step == 0:
            # Train and save logs
            # Guard against division by zero if no episode has finished since the last log
            average_reward = sum(episode_rewards)/len(episode_rewards) if episode_rewards else 0.0
            summary_str, _ = sess.run([summary_op, DQN_train_step_op], {state_ph: states_batch, action_ph: actions_batch, target_ph: targets, reward_var: average_reward, epsilon_var: epsilon})
            summary_writer.add_summary(summary_str, train_step)
            # Reset rewards buffer
            episode_rewards = []
        else:
            # Just train
            _ = sess.run(DQN_train_step_op, {state_ph: states_batch, action_ph: actions_batch, target_ph: targets})

        # Update target network
        if train_step % args.update_target_step == 0:
            sess.run(update_target_op)

        # Calculate time per step and display progress to console
        duration = time.time() - start_time
        duration_values.append(duration)
        ave_duration = sum(duration_values)/float(len(duration_values))

        sys.stdout.write('\x1b[2K\rStep {:d}/{:d} \t ({:.3f} s/step)'.format(train_step, args.num_steps_train, ave_duration))
        sys.stdout.flush()

        # Save checkpoint
        if train_step % args.save_ckpt_step == 0:
            saver.save(sess, checkpoint_path, global_step=train_step)
            sys.stdout.write('\n Checkpoint saved\n')
            sys.stdout.flush()

            # Reset time calculation
            duration_values = []



if __name__ == '__main__':
    train_args = get_train_args()
    train(train_args)
Comment on lines +271 to +273
Collaborator

This might want to be changed (rather than command line arguments, this could be a Trainer class with all the above methods and a constructor that takes in the relevant arguments). The only reason I would recommend this is that it would make it easier to test it.

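To make that concrete, a rough sketch of what such a Trainer-style interface could look like; the class name, constructor arguments, and defaults (e.g. the 240/256 frame values) are hypothetical and would need to line up with whatever the tests below end up expecting:

```python
class Trainer:
    """Hypothetical wrapper around the training loop, configured via constructor
    arguments instead of argparse (sketch only, not the final interface)."""

    def __init__(self, env='SuperMarioBros-1-1-v0', frame_width=240, frame_height=256,
                 frames_per_state=4, ckpt_dir='./ckpts', **hyperparams):
        self.env = env
        self.frame_width = frame_width
        self.frame_height = frame_height
        self.frames_per_state = frames_per_state
        self.ckpt_dir = ckpt_dir
        self.hyperparams = hyperparams  # learning_rate, batch_size, etc.

    def train(self, num_steps=1000):
        # Would wrap the existing train() loop above; shown here only to
        # illustrate how tests could construct and drive the trainer directly.
        raise NotImplementedError

# The command-line entry point could then simply forward parsed args:
# trainer = Trainer(**vars(get_train_args()))
# trainer.train()
```

This would let the tests construct a Trainer with small, fast settings instead of going through argparse defaults.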
19 changes: 19 additions & 0 deletions test/test_train.py
@@ -0,0 +1,19 @@
import os

from src.Model import Model
from src.train import Train

def test_train_init():
    model = Model(100, 50)
    trainer = Train(env='SuperMarioBros-1-1-v0')
    assert trainer.env == 'SuperMarioBros-1-1-v0'
    assert trainer.frame_width == 240
    assert trainer.frame_height == 256

def test_train_train():
    model = Model(100, 50)
    trainer = Train(env='SuperMarioBros-1-1-v0')
    trainer.train()
    assert os.path.exists('./ckpts')