Skip to content

Commit

Permalink
merged
Browse files Browse the repository at this point in the history
  • Loading branch information
phildue committed Jun 14, 2018
1 parent ce95600 commit d499cc6
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,13 @@
BATCH_SIZE = 4
MODEL_NAME = 'm2m_lstm'
POSITION = (39.7392, -104.99903)
FEATURES_TRAIN = ['air_temperature', 'humidity']
FEATURES_TRAIN = ['air_temperature']
FEATURES_PREDICT = ['air_temperature']
FILENAMES_TRAIN = ['2016', '2015', '2014', '2013', '2012', '2011', '2010', '2009', '2008']
FILENAMES_VALID = ['2017']
T_TRAIN_H = 7 * 24
T_PRED_D = 3
MASK_VALUE = 999


def train(batch_size=BATCH_SIZE,
Expand Down Expand Up @@ -217,13 +218,13 @@ def train(batch_size=BATCH_SIZE,
help='Name of the model to train. Available:\n ' + str([str(k) for k in models.keys()]),
default=MODEL_NAME)
argparser.add_argument('--log_dir', help='Path to store training output', default=LOG_DIR)
argparser.add_argument('--data_dir', help='Path to read data files', default=DATA_DIR)
argparser.add_argument('--data_dir', help='Path to read data files', default=RADIUS)
argparser.add_argument('--batch_size', help='Size of one Batch', default=BATCH_SIZE)
argparser.add_argument('--n_samples', help='Amount of samples to train', default=None)
args = argparser.parse_args()

train(batch_size=args.batch_size,
log_dir=args.log_dir,
data_dir=args.data_dir,
model_name=args.model_name,
n_samples=args.n_samples)
train()
# train(batch_size=args.batch_size,
# log_dir=args.log_dir,
# data_dir=args.data_dir,
# model_name=args.model_name,
# n_samples=args.n_samples)

0 comments on commit d499cc6

Please sign in to comment.