-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_model.py
33 lines (24 loc) · 1.06 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from parameters import *
from train_a_sequence import run
import helper_methods
from summaries import log_run_settings
# Entry-point script: trains the model progressively ("curriculum"-style)
# over increasing sequence lengths, one training stage per length in
# [FLAGS.START_SEQ, FLAGS.END_SEQ).
create_logs_dir()
log_run_settings()

# Only the character maps are needed at this level; NOTE(review): the fixed
# seq_length=32 here looks like a vocabulary-building pass independent of the
# per-stage training lengths — confirm against load_dataset's implementation.
_, charmap, inv_charmap = helper_methods.load_dataset(seq_length=32, b_lines=False)

# Remember the configured batch size so DYNAMIC_BATCH can rescale from the
# original value each stage rather than compounding divisions.
REAL_BATCH_SIZE = FLAGS.BATCH_SIZE
stages = range(FLAGS.START_SEQ, FLAGS.END_SEQ)
print('------------------Stages : ' + ' '.join(map(str, stages)) + "--------------")

for i, seq_length in enumerate(stages):
    # Previous stage's length (0 for the first stage) tells run() which
    # checkpoint, if any, to continue from.
    prev_seq_length = stages[i - 1] if i > 0 else 0
    print(
        "------------------Training on Seq Len = %d, BATCH SIZE: %d------------------" % (
            seq_length, BATCH_SIZE))
    # TF1-style: start each stage with a fresh graph.
    tf.reset_default_graph()
    if FLAGS.SCHEDULE_ITERATIONS:
        # Scale iteration count with sequence length, capped at the configured max.
        iterations = min((seq_length + 1) * FLAGS.SCHEDULE_MULT, FLAGS.ITERATIONS_PER_SEQ_LENGTH)
    else:
        iterations = FLAGS.ITERATIONS_PER_SEQ_LENGTH
    # Train from scratch only on the very first stage and only when not
    # explicitly resuming from a checkpoint.
    run(iterations, seq_length,
        seq_length == stages[0] and not FLAGS.TRAIN_FROM_CKPT,
        charmap, inv_charmap, prev_seq_length)
    if FLAGS.DYNAMIC_BATCH:
        # BUG FIX: the original used '/', which is true (float) division in
        # Python 3 and would turn BATCH_SIZE into a float. Use floor division,
        # clamped to at least 1 so a long seq_length can't zero the batch.
        BATCH_SIZE = max(1, REAL_BATCH_SIZE // seq_length)