Model Training #20
Conversation
src/train.py
Outdated
def get_train_args():
    train_params = argparse.ArgumentParser()

    # Environment parameters
    train_params.add_argument("--env", type=str, default='BreakoutDeterministic-v4', help="Environment to use (must have RGB image state space and discrete action space)")
    # store_true avoids the type=bool pitfall (argparse would parse any non-empty string, including "False", as True)
    train_params.add_argument("--render", action='store_true', help="Whether or not to display the environment on the screen during training")
    train_params.add_argument("--random_seed", type=int, default=1234, help="Random seed for reproducibility")
    train_params.add_argument("--frame_width", type=int, default=105, help="Frame width after resize.")
    train_params.add_argument("--frame_height", type=int, default=80, help="Frame height after resize.")
    train_params.add_argument("--frames_per_state", type=int, default=4, help="Sequence of frames which constitutes a single state.")

    # Training parameters
    train_params.add_argument("--num_steps_train", type=int, default=50000000, help="Number of steps to train for")
    train_params.add_argument("--train_frequency", type=int, default=4, help="Perform training step every N game steps.")
    train_params.add_argument("--max_ep_steps", type=int, default=2000, help="Maximum number of steps per episode")
    train_params.add_argument("--batch_size", type=int, default=32, help="Batch size for sampling from replay memory")
    train_params.add_argument("--learning_rate", type=float, default=0.00025, help="Learning rate for the optimizer")
    train_params.add_argument("--replay_mem_size", type=int, default=1000000, help="Maximum size of replay memory buffer")
    train_params.add_argument("--initial_replay_mem_size", type=int, default=50000, help="Initial size of replay memory (populated by random actions) before learning can start")
    train_params.add_argument("--epsilon_start", type=float, default=1.0, help="Exploration rate at the beginning of training.")
    train_params.add_argument("--epsilon_end", type=float, default=0.1, help="Exploration rate at the end of decay.")
    train_params.add_argument("--epsilon_step_end", type=int, default=1000000, help="After how many steps to stop decaying the exploration rate.")
    train_params.add_argument("--discount_rate", type=float, default=0.99, help="Discount rate (gamma) for future rewards.")
    train_params.add_argument("--update_target_step", type=int, default=10000, help="Copy current network parameters to target network every N steps.")
    train_params.add_argument("--save_ckpt_step", type=int, default=250000, help="Save checkpoint every N steps")
    train_params.add_argument("--save_log_step", type=int, default=1000, help="Save logs every N steps")

    # Files/directories
    train_params.add_argument("--ckpt_dir", type=str, default='./ckpts', help="Directory for saving/loading checkpoints")
    train_params.add_argument("--ckpt_file", type=str, default=None, help="Checkpoint file to load and resume training from (if None, train from scratch)")
    train_params.add_argument("--log_dir", type=str, default='./logs/train', help="Directory for saving logs")

    return train_params.parse_args()
We might not need all of these, and some of the defaults may not make sense (e.g. env, frame width, frame height).
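As a side note, the three epsilon flags together define a linear exploration-decay schedule. A minimal sketch of how they would typically be consumed (the epsilon_by_step helper is illustrative, not part of this PR):

def epsilon_by_step(step, start=1.0, end=0.1, step_end=1000000):
    # Linearly anneal epsilon from `start` to `end` over the first `step_end`
    # steps, then hold it at `end` (matches --epsilon_start, --epsilon_end,
    # and --epsilon_step_end above).
    fraction = min(step / step_end, 1.0)
    return start + fraction * (end - start)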
from_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, from_scope)
to_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, to_scope)

op_holder = []

# Update old network parameters with new network parameters
# (relies on both scopes creating their variables in the same order)
for from_var, to_var in zip(from_vars, to_vars):
    op_holder.append(to_var.assign(from_var))

return op_holder
I'm assuming this will need to be compatible with @rithvikb's model code, but this is fine as a placeholder.
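For context, a sketch of how these ops would typically be used: run them once every --update_target_step steps to copy the online network's weights into the target network. The function and scope names below are assumptions, since the diff doesn't show them:

update_target_ops = update_target_network('online', 'target')  # hypothetical name and scopes

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(args.num_steps_train):
        # ... environment interaction and training step ...
        if step % args.update_target_step == 0:
            sess.run(update_target_ops)  # sync target network with online network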
state_ph = tf.placeholder(tf.uint8, (None, args.frame_height, args.frame_width, args.frames_per_state))
action_ph = tf.placeholder(tf.int32, (None,))  # (None,) not (None): a bare (None) is just None, i.e. an unconstrained shape
target_ph = tf.placeholder(tf.float32, (None,))
These might not end up being tf.placeholders, but again, that won't be clear until @rithvikb creates a PR.
Again, this is fine as a placeholder for this PR.
This applies in multiple places below, but that is fine.
if __name__ == '__main__':
    train_args = get_train_args()
    train(train_args)
This might want to be changed: rather than taking command-line arguments, this could be a Trainer class with all of the above methods and a constructor that takes the relevant arguments. The only reason I would recommend this is that it would make the code easier to test (see the sketch below).
Make sure to add tests for training as well.
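A minimal sketch of what that refactor might look like (the Trainer class and its method names are suggestions, not code from this PR):

class Trainer:
    def __init__(self, env='BreakoutDeterministic-v4', batch_size=32,
                 learning_rate=0.00025, discount_rate=0.99, **kwargs):
        # Mirror the CLI defaults so tests can construct a Trainer directly
        # with small overrides instead of parsing argv.
        self.env = env
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.discount_rate = discount_rate

    def train(self, num_steps):
        ...  # the existing training loop would move in here

if __name__ == '__main__':
    args = get_train_args()
    Trainer(**vars(args)).train(args.num_steps_train)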