diff --git a/quick_train.py b/quick_train.py index 16e418140..62dcfe885 100644 --- a/quick_train.py +++ b/quick_train.py @@ -27,7 +27,6 @@ file_pattern = "./data/train_data/quick_train_data/tfrecords/quick_train_data*.tfrecords" anchors = utils.get_anchors('./data/yolo_anchors.txt') -is_training = tf.placeholder(dtype=tf.bool, name="phase_train") dataset = tf.data.TFRecordDataset(filenames = tf.gfile.Glob(file_pattern)) dataset = dataset.map(utils.parser(anchors, num_classes).parser_example, num_parallel_calls = 10) dataset = dataset.repeat().shuffle(SHUFFLE_SIZE).batch(BATCH_SIZE).prefetch(BATCH_SIZE) @@ -36,35 +35,18 @@ images, *y_true = example model = yolov3.yolov3(num_classes) with tf.variable_scope('yolov3'): - y_pred = model.forward(images, is_training=is_training) + y_pred = model.forward(images, is_training=False) loss = model.compute_loss(y_pred, y_true) y_pred = model.predict(y_pred) - - -# # train -# optimizer = tf.train.AdamOptimizer(LR) -# train_op = optimizer.minimize(loss[0]) -# sess.run(tf.global_variables_initializer()) -# for epoch in range(EPOCHS): - # run_items = sess.run([train_op, y_pred, y_true] + loss, feed_dict={is_training:True}) - # rec, prec, mAP = utils.evaluate(run_items[1], run_items[2], num_classes) - - # print("=> EPOCH: %2d\ttotal_loss:%7.4f\tloss_coord:%7.4f\tloss_sizes:%7.4f\tloss_confs:%7.4f\tloss_class:%7.4f" - # "\trec:%.2f\tprec:%.2f\tmAP:%.2f" - # %(epoch, run_items[3], run_items[4], run_items[5], run_items[6], run_items[7], rec, prec, mAP)) - - - -# test -load_ops = utils.load_weights(tf.global_variables(scope='yolov3'), weights_path) -sess.run(load_ops) + load_ops = utils.load_weights(tf.global_variables(scope='yolov3'), weights_path) + sess.run(load_ops) for epoch in range(EPOCHS): - run_items = sess.run([y_pred, y_true] + loss, feed_dict={is_training:False}) + run_items = sess.run([y_pred, y_true] + loss) rec, prec, mAP = utils.evaluate(run_items[0], run_items[1], num_classes, score_thresh=0.3, iou_thresh=0.5) print("=> EPOCH: %2d\ttotal_loss:%7.4f\tloss_coord:%7.4f\tloss_sizes:%7.4f\tloss_confs:%7.4f\tloss_class:%7.4f" - "\trec:%.2f\tprec:%.2f\tmAP:%.2f" + "\trec:%7.4f\tprec:%7.4f\tmAP:%7.4f" %(epoch, run_items[2], run_items[3], run_items[4], run_items[5], run_items[6], rec, prec, mAP))