
Model Training #20

Closed
wants to merge 8 commits into from

Conversation

tcliff30
Collaborator

No description provided.

@tcliff30 tcliff30 self-assigned this Oct 11, 2021
@tcliff30 tcliff30 linked an issue Oct 11, 2021 that may be closed by this pull request
@tcliff30 tcliff30 changed the title not supposed to work: just a base for linking issue Model Training Oct 15, 2021
src/train.py Outdated
Comment on lines 20 to 52
def get_train_args():
    train_params = argparse.ArgumentParser()

    # Environment parameters
    train_params.add_argument("--env", type=str, default='BreakoutDeterministic-v4', help="Environment to use (must have RGB image state space and discrete action space)")
    # Note: argparse's type=bool is a trap (bool('False') is True), so use a store_true flag instead
    train_params.add_argument("--render", action='store_true', help="Whether or not to display the environment on the screen during training")
    train_params.add_argument("--random_seed", type=int, default=1234, help="Random seed for reproducibility")
    train_params.add_argument("--frame_width", type=int, default=105, help="Frame width after resize.")
    train_params.add_argument("--frame_height", type=int, default=80, help="Frame height after resize.")
    train_params.add_argument("--frames_per_state", type=int, default=4, help="Sequence of frames which constitutes a single state.")

    # Training parameters
    train_params.add_argument("--num_steps_train", type=int, default=50000000, help="Number of steps to train for")
    train_params.add_argument("--train_frequency", type=int, default=4, help="Perform training step every N game steps.")
    train_params.add_argument("--max_ep_steps", type=int, default=2000, help="Maximum number of steps per episode")
    train_params.add_argument("--batch_size", type=int, default=32, help="Minibatch size sampled from replay memory")
    train_params.add_argument("--learning_rate", type=float, default=0.00025, help="Optimizer learning rate")
    train_params.add_argument("--replay_mem_size", type=int, default=1000000, help="Maximum size of replay memory buffer")
    train_params.add_argument("--initial_replay_mem_size", type=int, default=50000, help="Initial size of replay memory (populated by random actions) before learning can start")
    train_params.add_argument("--epsilon_start", type=float, default=1.0, help="Exploration rate at the beginning of training.")
    train_params.add_argument("--epsilon_end", type=float, default=0.1, help="Exploration rate at the end of decay.")
    train_params.add_argument("--epsilon_step_end", type=int, default=1000000, help="After how many steps to stop decaying the exploration rate.")
    train_params.add_argument("--discount_rate", type=float, default=0.99, help="Discount rate (gamma) for future rewards.")
    # These two are step counts, so type=int rather than type=float
    train_params.add_argument("--update_target_step", type=int, default=10000, help="Copy current network parameters to target network every N steps.")
    train_params.add_argument("--save_ckpt_step", type=int, default=250000, help="Save checkpoint every N steps")
    train_params.add_argument("--save_log_step", type=int, default=1000, help="Save logs every N steps")

    # Files/directories
    train_params.add_argument("--ckpt_dir", type=str, default='./ckpts', help="Directory for saving/loading checkpoints")
    train_params.add_argument("--ckpt_file", type=str, default=None, help="Checkpoint file to load and resume training from (if None, train from scratch)")
    train_params.add_argument("--log_dir", type=str, default='./logs/train', help="Directory for saving logs")

    return train_params.parse_args()
Collaborator

We might not need all of these, and some of the defaults might not make sense (env, frame width, frame height).
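One thing the parameters above do pin down is the exploration schedule: `epsilon_start`, `epsilon_end`, and `epsilon_step_end` suggest a linear anneal. A minimal sketch of that schedule, assuming the usual linear form (the actual train loop may compute it differently):

```python
def epsilon_at(step, epsilon_start=1.0, epsilon_end=0.1, epsilon_step_end=1000000):
    """Linearly anneal the exploration rate from epsilon_start to epsilon_end
    over the first epsilon_step_end steps, then hold it constant.
    Defaults mirror the argparse defaults in get_train_args()."""
    if step >= epsilon_step_end:
        return epsilon_end
    frac = step / epsilon_step_end
    return epsilon_start + frac * (epsilon_end - epsilon_start)
```

For example, halfway through the decay (`step=500000` with the defaults) this gives an exploration rate of 0.55.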

Comment on lines +68 to +77
from_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, from_scope)
to_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, to_scope)

op_holder = []

# Update old (target) network parameters with current network parameters
for from_var, to_var in zip(from_vars, to_vars):
    op_holder.append(to_var.assign(from_var))

return op_holder
Collaborator

I'm assuming this will need to be compatible with @rithvikb's model code, but this is fine as a placeholder.
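For anyone reading the diff cold: the `op_holder` snippet builds a hard "target network" update, i.e. copying every trainable variable from one scope to the other when the ops are run. A framework-free illustration of the same idea, with plain dicts standing in for the two variable collections (all names here are hypothetical, not from the PR):

```python
def build_update_ops(from_vars, to_vars):
    """Return a list of thunks; running each one copies a single
    parameter from the online network dict into the target dict,
    mirroring what sess.run(op_holder) does in the TF1 code."""
    ops = []
    for name in from_vars:
        # Bind name at definition time via the default argument
        ops.append(lambda n=name: to_vars.__setitem__(n, from_vars[n]))
    return ops

online = {"w1": 0.5, "w2": -1.2}   # current (online) network parameters
target = {"w1": 0.0, "w2": 0.0}    # stale target network parameters

for op in build_update_ops(online, target):
    op()
# target now mirrors online
```

In the real training loop this copy would happen every `update_target_step` steps.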

Comment on lines +89 to +91
state_ph = tf.placeholder(tf.uint8, (None, args.frame_height, args.frame_width, args.frames_per_state))
action_ph = tf.placeholder(tf.int32, (None,))
target_ph = tf.placeholder(tf.float32, (None,))
Collaborator


These might not end up as tf.placeholders, but again, that won't be clear until @rithvikb creates a PR.

Collaborator


Again, this is fine as a placeholder for this PR.

Collaborator


This applies in multiple places below as well, but that is fine.

Comment on lines +260 to +262
if __name__ == '__main__':
    train_args = get_train_args()
    train(train_args)
Collaborator


This might want to be changed: rather than command-line arguments, this could be a Trainer class with all the above methods and a constructor that takes the relevant arguments. The main reason I would recommend this is that it would make the training code easier to test.
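A hedged sketch of that refactor: a `Trainer` class whose constructor takes the same values `get_train_args()` parses, so tests can construct it directly instead of going through the command line. The class and method names here are hypothetical, and only a few of the parameters are shown:

```python
class Trainer:
    """Sketch of the suggested refactor. The training methods from
    train.py would become methods on this class."""

    def __init__(self, env='BreakoutDeterministic-v4', batch_size=32,
                 learning_rate=0.00025, discount_rate=0.99,
                 num_steps_train=50000000):
        self.env = env
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.discount_rate = discount_rate
        self.num_steps_train = num_steps_train

    @classmethod
    def from_args(cls, args):
        # Bridge from the existing argparse Namespace, so the CLI entry
        # point keeps working: train(Trainer.from_args(get_train_args()))
        return cls(env=args.env, batch_size=args.batch_size,
                   learning_rate=args.learning_rate,
                   discount_rate=args.discount_rate,
                   num_steps_train=args.num_steps_train)

# In a test, no command line is needed:
trainer = Trainer(num_steps_train=100)
```

The win is that a test can build a tiny `Trainer(num_steps_train=100)` without patching `sys.argv`.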

@sagars729
Collaborator

Make sure to add tests for training as well
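One low-effort way to make the argument parsing itself testable, even without the Trainer refactor: have `get_train_args` accept an optional argv list and pass it to `parse_args`, which reads `sys.argv` only when given `None`. A sketch mirroring a subset of the real parser (only two of its arguments shown):

```python
import argparse

def get_train_args(argv=None):
    """Same shape as the function in train.py, but parse_args(argv)
    lets tests inject arguments instead of relying on sys.argv."""
    train_params = argparse.ArgumentParser()
    train_params.add_argument("--batch_size", type=int, default=32)
    train_params.add_argument("--learning_rate", type=float, default=0.00025)
    return train_params.parse_args(argv)

# Tests can then exercise defaults and overrides without a real CLI:
defaults = get_train_args([])
overridden = get_train_args(["--batch_size", "64"])
```

Production callers are unaffected since `parse_args(None)` is the existing behavior.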

@rithvikb rithvikb closed this Dec 10, 2021
Development

Successfully merging this pull request may close these issues.

Model Trainer- Training the Model (Tim 3)
3 participants