diff --git a/opts.py b/opts.py index 35daf4c9..07c320e3 100644 --- a/opts.py +++ b/opts.py @@ -1,6 +1,6 @@ from __future__ import print_function import argparse - +import misc.utils as utils def parse_opt(): parser = argparse.ArgumentParser() @@ -216,6 +216,11 @@ def parse_opt(): args.structure_sample_n = args.structure_sample_n or args.train_sample_n + + # Deal with feature things before anything + args.use_fc, args.use_att = utils.if_use_feat(args.caption_model) + if args.use_box: args.att_feat_size = args.att_feat_size + 5 + return args diff --git a/train.py b/train.py index 072de6dd..c1a85ab5 100644 --- a/train.py +++ b/train.py @@ -30,9 +30,6 @@ def add_summary_value(writer, key, value, iteration): writer.add_scalar(key, value, iteration) def train(opt): - # Deal with feature things before anything - opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model) - if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5 ################################ # Build dataloader